refactor(connlib): delay initialization of Sockets until we have a tokio runtime (#4286)

Our sockets need to be initialized within a tokio runtime context. To
achieve this, we don't actually initialize anything on `Sockets::new`.
Instead, we call `rebind` within the constructor of `Tunnel` which
already runs in a tokio context.

Fixes: #4282

---------

Signed-off-by: Jamil <jamilbk@users.noreply.github.com>
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
Co-authored-by: Reactor Scram <ReactorScram@users.noreply.github.com>
This commit is contained in:
Jamil
2024-03-25 15:51:35 -07:00
committed by GitHub
parent cfc1fb0488
commit 228389882e
10 changed files with 110 additions and 85 deletions

View File

@@ -388,28 +388,27 @@ fn connect(
.enable_all()
.build()?;
let sockets = Sockets::new()?;
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
callback_handler.protect_file_descriptor(ip4_socket)?;
}
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
callback_handler.protect_file_descriptor(ip6_socket)?;
}
let sockets = Sockets::with_protect({
let callbacks = callback_handler.clone();
move |fd| {
callbacks
.protect_file_descriptor(fd)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
});
let session = Session::connect(
login,
sockets,
private_key,
Some(os_version),
callback_handler.clone(),
callback_handler,
Some(MAX_PARTITION_TIME),
runtime.handle().clone(),
)?;
Ok(SessionWrapper {
inner: session,
callbacks: callback_handler,
runtime,
})
}
@@ -463,29 +462,11 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co
pub struct SessionWrapper {
inner: Session,
callbacks: CallbackHandler,
#[allow(dead_code)] // Only here so we don't drop the memory early.
runtime: Runtime,
}
impl SessionWrapper {
fn reconnect(&self) -> Result<(), CallbackError> {
let sockets = Sockets::new()?;
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
self.callbacks.protect_file_descriptor(ip4_socket)?;
}
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
self.callbacks.protect_file_descriptor(ip6_socket)?;
}
self.inner.reconnect(sockets);
Ok(())
}
}
/// # Safety
/// Pointers must be valid
#[allow(non_snake_case)]
@@ -528,11 +509,9 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_reconnect(
mut env: JNIEnv,
_: JNIEnv,
_: JClass,
session: *const SessionWrapper,
) {
if let Err(e) = (*session).reconnect() {
throw(&mut env, "java/lang/Exception", e.to_string());
}
(*session).inner.reconnect();
}

View File

@@ -205,7 +205,7 @@ impl WrappedSession {
let session = Session::connect(
login,
Sockets::new().map_err(|err| err.to_string())?,
Sockets::new(),
private_key,
os_version_override,
CallbackHandler {

View File

@@ -10,7 +10,7 @@ use connlib_shared::{
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId},
Callbacks,
};
use firezone_tunnel::{ClientTunnel, Sockets};
use firezone_tunnel::ClientTunnel;
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use std::{
collections::HashMap,
@@ -37,7 +37,7 @@ pub struct Eventloop<C: Callbacks> {
/// Commands that can be sent to the [`Eventloop`].
pub enum Command {
Stop,
Reconnect(Sockets),
Reconnect,
SetDns(Vec<IpAddr>),
}
@@ -71,9 +71,11 @@ where
tracing::warn!("Failed to update DNS: {e}");
}
}
Poll::Ready(Some(Command::Reconnect(sockets))) => {
Poll::Ready(Some(Command::Reconnect)) => {
self.portal.reconnect();
self.tunnel.reconnect(sockets);
if let Err(e) = self.tunnel.reconnect() {
tracing::warn!("Failed to reconnect tunnel: {e}");
}
continue;
}

View File

@@ -67,11 +67,11 @@ impl Session {
///
/// - Close and re-open a connection to the portal.
/// - Refresh all allocations
/// - Replace the currently used [`Sockets`] with the provided one
/// - Rebind local UDP sockets
///
/// # Implementation note
///
/// The reason we replace [`Sockets`] are:
/// The reason we rebind the UDP sockets are:
///
/// 1. On MacOS, as socket bound to the unspecified IP cannot send to interfaces attached after the socket has been created.
/// 2. Switching between networks changes the 3-tuple of the client.
@@ -80,9 +80,9 @@ impl Session {
/// Changing the IP would be enough for that.
/// However, if the user would now change _back_ to the previous network,
/// the TURN server would recognise the old allocation but the client already lost all its state associated with it.
/// To avoid race-conditions like this, we initialize a new [`Sockets`] instance which allocates a new port.
pub fn reconnect(&self, sockets: Sockets) {
let _ = self.channel.send(Command::Reconnect(sockets));
/// To avoid race-conditions like this, we rebind the sockets to a new port.
pub fn reconnect(&self) {
let _ = self.channel.send(Command::Reconnect);
}
/// Sets a new set of upstream DNS servers for this [`Session`].
@@ -118,7 +118,7 @@ async fn connect<CB>(
where
CB: Callbacks + 'static,
{
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone());
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone())?;
let portal = PhoenixChannel::connect(
Secret::new(url),

View File

@@ -43,8 +43,13 @@ pub enum Input<'a, I> {
}
impl Io {
pub fn new(sockets: Sockets) -> Self {
Self {
/// Creates a new I/O abstraction
///
/// Must be called within a Tokio runtime context so we can bind the sockets.
pub fn new(mut sockets: Sockets) -> io::Result<Self> {
sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context.
Ok(Self {
device: Device::new(),
timeout: None,
sockets,
@@ -53,7 +58,7 @@ impl Io {
Duration::from_secs(60),
DNS_QUERIES_QUEUE_SIZE,
),
}
})
}
pub fn poll<'b>(
@@ -115,8 +120,8 @@ impl Io {
&self.sockets
}
pub(crate) fn set_sockets(&mut self, sockets: Sockets) {
self.sockets = sockets;
pub fn sockets_mut(&mut self) -> &mut Sockets {
&mut self.sockets
}
pub fn set_upstream_dns_servers(

View File

@@ -59,21 +59,27 @@ impl<CB> ClientTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
Self {
io: Io::new(sockets),
pub fn new(
private_key: StaticSecret,
sockets: Sockets,
callbacks: CB,
) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(sockets)?,
callbacks,
role_state: ClientState::new(private_key),
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
}
})
}
pub fn reconnect(&mut self, sockets: Sockets) {
pub fn reconnect(&mut self) -> std::io::Result<()> {
self.role_state.reconnect(Instant::now());
self.io.set_sockets(sockets);
self.io.sockets_mut().rebind()?;
Ok(())
}
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<ClientEvent>> {
@@ -149,16 +155,20 @@ impl<CB> GatewayTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
Self {
io: Io::new(sockets),
pub fn new(
private_key: StaticSecret,
sockets: Sockets,
callbacks: CB,
) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(sockets)?,
callbacks,
role_state: GatewayState::new(private_key),
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
}
})
}
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<GatewayEvent>> {

View File

@@ -13,10 +13,47 @@ use crate::Result;
pub struct Sockets {
socket_v4: Option<Socket>,
socket_v6: Option<Socket>,
#[cfg(unix)]
protect: Box<dyn Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static>,
}
impl Default for Sockets {
fn default() -> Self {
Self::new()
}
}
impl Sockets {
pub fn new() -> io::Result<Self> {
#[cfg(unix)]
pub fn with_protect(
protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static,
) -> Self {
Self {
socket_v4: None,
socket_v6: None,
#[cfg(unix)]
protect: Box::new(protect),
}
}
pub fn new() -> Self {
Self {
socket_v4: None,
socket_v6: None,
#[cfg(unix)]
protect: Box::new(|_| Ok(())),
}
}
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
match addr {
SocketAddr::V4(_) => self.socket_v4.is_some(),
SocketAddr::V6(_) => self.socket_v6.is_some(),
}
}
pub fn rebind(&mut self) -> io::Result<()> {
let socket_v4 = Socket::ip4();
let socket_v6 = Socket::ip6();
@@ -39,31 +76,23 @@ impl Sockets {
_ => (),
}
Ok(Self {
socket_v4: socket_v4.ok(),
socket_v6: socket_v6.ok(),
})
}
#[cfg(unix)]
{
use std::os::fd::AsRawFd;
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
match addr {
SocketAddr::V4(_) => self.socket_v4.is_some(),
SocketAddr::V6(_) => self.socket_v6.is_some(),
if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) {
(self.protect)(fd)?;
}
if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) {
(self.protect)(fd)?;
}
}
}
#[cfg(unix)]
pub fn ip4_socket_fd(&self) -> Option<std::os::fd::RawFd> {
use std::os::fd::AsRawFd;
self.socket_v4 = socket_v4.ok();
self.socket_v6 = socket_v6.ok();
self.socket_v4.as_ref().map(|s| s.socket.as_raw_fd())
}
#[cfg(unix)]
pub fn ip6_socket_fd(&self) -> Option<std::os::fd::RawFd> {
use std::os::fd::AsRawFd;
self.socket_v6.as_ref().map(|s| s.socket.as_raw_fd())
Ok(())
}
/// Flushes all buffered data on the sockets.

View File

@@ -91,7 +91,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
}
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new()?, CallbackHandler);
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?;
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
Secret::new(login),

View File

@@ -536,7 +536,7 @@ impl Controller {
)?;
let connlib = connlib_client_shared::Session::connect(
login,
Sockets::new()?,
Sockets::new(),
private_key,
None,
callback_handler.clone(),
@@ -837,7 +837,7 @@ async fn run_controller(
have_internet = new_have_internet;
if let Some(session) = controller.session.as_mut() {
tracing::debug!("Internet up/down changed, calling `Session::reconnect`");
session.connlib.reconnect(Sockets::new()?);
session.connlib.reconnect();
}
}
},

View File

@@ -38,7 +38,7 @@ async fn main() -> Result<()> {
let session = Session::connect(
login,
Sockets::new()?,
Sockets::new(),
private_key,
None,
callbacks.clone(),
@@ -62,7 +62,7 @@ async fn main() -> Result<()> {
if sighup.poll_recv(cx).is_ready() {
tracing::debug!("Received SIGHUP");
session.reconnect(Sockets::new()?);
session.reconnect();
continue;
}