diff --git a/.codespellrc b/.codespellrc index 52e89eb8c..a2c339d12 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] skip = ./rust/target,Cargo.lock,./www/docs/reference/api/*.mdx,./erl_crash.dump,./apps/*/erl_crash.dump,./cover,./vendor,*.json,yarn.lock,seeds.exs,./**/node_modules,./deps,./priv/static,./priv/plts,./**/priv/static,./.git,./www/build,./_build -ignore-words-list = crate,keypair,keypairs,iif,statics,wee,anull +ignore-words-list = crate,keypair,keypairs,iif,statics,wee,anull,commitish diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5731f1d6a..f3fd73f57 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -18,6 +18,8 @@ jobs: tag_name: ${{ steps.release_drafter.outputs.tag_name }} steps: - uses: release-drafter/release-drafter@v5 + with: + commitish: cloud id: release_drafter env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -43,7 +45,7 @@ jobs: env: RUSTDOCFLAGS: "-D warnings" - run: cargo clippy -p relay --all-targets --all-features -- -D warnings - - run: cargo test + - run: cargo test -p relay test-connlib: needs: draft-release @@ -67,11 +69,30 @@ jobs: - name: Update toolchain run: rustup show - uses: Swatinem/rust-cache@v2 + with: + workspaces: ./rust + # TODO: Building *ring* from git requires us to install additional tools; + # once we're not using a forked *ring* these 2 steps can be removed. + - if: ${{ contains(matrix.runs-on, 'windows') }} + name: Install *ring* build tools + run: | + git clone ` + --branch windows ` + --depth 1 ` + https://github.com/briansmith/ring-toolchain ` + target/tools/windows + # The repo above is for a newer version of the *ring* build script which + # expects different paths; instead of going through the trouble of + # copying the older installation script let's just move the exe. + - if: ${{ contains(matrix.runs-on, 'windows') }} + name: Move *ring* build tools + run: | + mv target/tools/windows/nasm/nasm.exe target/tools/nasm.exe - name: Run connlib checks and tests run: | - cargo check --workspace --exclude relay - cargo clippy --workspace --exclude relay -- -D clippy::all - cargo test --workspace --exclude relay + cargo check --workspace --exclude connlib-apple --exclude relay + cargo clippy --workspace --exclude connlib-apple --exclude relay -- -D clippy::all + cargo test --workspace --exclude connlib-apple --exclude relay build-android: needs: @@ -86,6 +107,8 @@ jobs: steps: - uses: actions/checkout@v3 - uses: Swatinem/rust-cache@v2 + with: + workspaces: ./rust - name: Update toolchain run: rustup show - uses: actions/cache@v3 @@ -129,8 +152,16 @@ jobs: steps: - uses: actions/checkout@v3 - uses: Swatinem/rust-cache@v2 + with: + workspaces: ./rust - name: Update toolchain run: rustup show + - name: Run connlib checks and tests + working-directory: ./rust/connlib/clients/apple + run: | + cargo check -p connlib-apple + cargo clippy -p connlib-apple -- -D clippy::all + cargo test -p connlib-apple - name: Setup lipo run: cargo install cargo-lipo - uses: actions/cache@v3 @@ -139,7 +170,6 @@ jobs: key: ${{ runner.os }}-spm-${{ hashFiles('**/Package.resolved') }} restore-keys: | ${{ runner.os }}-spm- - - name: Build Connlib.xcframework.zip env: CONFIGURATION: Release @@ -151,8 +181,8 @@ jobs: # build first. See https://github.com/briansmith/ring/issues/1332 ./build-rust.sh ./build-xcframework.sh - mv Connlib.xcframework.zip ../../../Connlib-${{ needs.draft-release.outputs.tag_name }}.xcframework.zip - mv Connlib.xcframework.zip.checksum.txt ../../../Connlib-${{ needs.draft-release.outputs.tag_name }}.xcframework.zip.checksum.txt + mv Connlib.xcframework.zip ../../../../Connlib-${{ needs.draft-release.outputs.tag_name }}.xcframework.zip + mv Connlib.xcframework.zip.checksum.txt ../../../../Connlib-${{ needs.draft-release.outputs.tag_name }}.xcframework.zip.checksum.txt - uses: actions/upload-artifact@v3 with: name: connlib-apple diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 007975969..e4af73ce7 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -435,6 +435,11 @@ version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +[[package]] +name = "cc" +version = "1.0.79" +source = "git+https://github.com/youknowone/cc-rs?rev=4ca92100c25ac2df679f0cce11c4c3e830f2e455#4ca92100c25ac2df679f0cce11c4c3e830f2e455" + [[package]] name = "ccm" version = "0.3.0" @@ -593,10 +598,13 @@ dependencies = [ name = "connlib-apple" version = "0.1.6" dependencies = [ + "anyhow", + "diva", "firezone-client-connlib", "libc", "swift-bridge", "swift-bridge-build 0.1.51 (git+https://github.com/conectado/swift-bridge.git?branch=fix-already-declared)", + "walkdir", ] [[package]] @@ -880,6 +888,17 @@ dependencies = [ "syn 2.0.18", ] +[[package]] +name = "diva" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4962b19d77f25a52081b27cd0404773376bda79b395801dcb771646206a20b06" +dependencies = [ + "libc", + "log", + "windows-sys 0.48.0", +] + [[package]] name = "ecdsa" version = "0.14.8" @@ -950,7 +969,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" dependencies = [ - "cc", + "cc 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)", "libc", ] @@ -1529,7 +1548,6 @@ dependencies = [ "futures", "futures-util", "ip_network", - "macros", "os_info", "rand_core 0.6.4", "rtnetlink", @@ -1567,15 +1585,6 @@ version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" -[[package]] -name = "macros" -version = "0.1.0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.18", -] - [[package]] name = "matchers" version = "0.1.0" @@ -2216,10 +2225,9 @@ dependencies = [ [[package]] name = "ring" version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +source = "git+https://github.com/firezone/ring?branch=v0.16.20-cc-fix#ed10dc23f020bd0d3e2af30a40464fad502aeeda" dependencies = [ - "cc", + "cc 1.0.79 (git+https://github.com/youknowone/cc-rs?rev=4ca92100c25ac2df679f0cce11c4c3e830f2e455)", "libc", "once_cell", "spin", @@ -3557,7 +3565,7 @@ dependencies = [ "async-trait", "bitflags", "bytes", - "cc", + "cc 1.0.79 (registry+https://github.com/rust-lang/crates.io-index)", "ipnet", "lazy_static", "libc", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 7eced66ef..6c2705637 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -10,9 +10,13 @@ members = [ "connlib/libs/gateway", "connlib/libs/common", "connlib/gateway", - "connlib/macros", ] [workspace.dependencies] boringtun = { git = "https://github.com/cloudflare/boringtun", rev = "878385f", default-features = false } swift-bridge = { git = "https://github.com/chinedufn/swift-bridge.git", rev = "4fbd30f" } + +# Patched to use https://github.com/rust-lang/cc-rs/pull/708 +# (the `patch` section can't be used for build deps...) +[patch.crates-io] +ring = { git = "https://github.com/firezone/ring", branch = "v0.16.20-cc-fix" } diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 1a80f8edf..f4065e1e9 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -1,43 +1,48 @@ -#[macro_use] -extern crate log; -extern crate android_logger; -extern crate jni; -use self::jni::JNIEnv; -use android_logger::Config; +// The "system" ABI is only needed for Java FFI on Win32, not Android: +// https://github.com/jni-rs/jni-rs/pull/22 +// However, this consideration has made it idiomatic for Java FFI in the Rust +// ecosystem, so it's used here for consistency. + use firezone_client_connlib::{ Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, }; -use jni::objects::{JClass, JObject, JString, JValue}; -use log::LevelFilter; +use jni::{ + objects::{JClass, JObject, JString, JValue}, + JNIEnv, +}; /// This should be called once after the library is loaded by the system. #[allow(non_snake_case)] #[no_mangle] pub extern "system" fn Java_dev_firezone_connlib_Logger_init(_: JNIEnv, _: JClass) { - #[cfg(debug_assertions)] - let level = LevelFilter::Trace; - #[cfg(not(debug_assertions))] - let level = LevelFilter::Warn; - android_logger::init_once( - Config::default() - // Allow all log levels - .with_max_level(level) + android_logger::Config::default() + .with_max_level(if cfg!(debug_assertions) { + log::LevelFilter::Trace + } else { + log::LevelFilter::Warn + }) .with_tag("connlib"), ) } -pub enum CallbackHandler {} +#[derive(Clone)] +pub struct CallbackHandler; + impl Callbacks for CallbackHandler { - fn on_update_resources(_resource_list: ResourceList) { + fn on_update_resources(&self, _resource_list: ResourceList) { todo!() } - fn on_set_tunnel_adresses(_tunnel_addresses: TunnelAddresses) { + fn on_connect(&self, _tunnel_addresses: TunnelAddresses) { todo!() } - fn on_error(_error: &Error, _error_type: ErrorType) { + fn on_disconnect(&self) { + todo!() + } + + fn on_error(&self, _error: &Error, _error_type: ErrorType) { todo!() } } @@ -57,7 +62,7 @@ pub unsafe extern "system" fn Java_dev_firezone_connlib_Session_connect( let portal_token: String = env.get_string(&portal_token).unwrap().into(); let session = Box::new( - Session::connect::(portal_url.as_str(), portal_token).expect("TODO!"), + Session::connect(portal_url.as_str(), portal_token, CallbackHandler).expect("TODO!"), ); // TODO: Get actual IPs returned from portal based on this device @@ -65,12 +70,12 @@ pub unsafe extern "system" fn Java_dev_firezone_connlib_Session_connect( let tunnel_addresses = env.new_string(tunnelAddressesJSON).unwrap(); match env.call_method( callback, - "onSetTunnelAddresses", + "onConnect", "(Ljava/lang/String;)Z", &[JValue::from(&tunnel_addresses)], ) { - Ok(res) => trace!("onSetTunnelAddresses returned {:?}", res), - Err(e) => error!("Failed to call setTunnelAddresses: {:?}", e), + Ok(res) => log::trace!("`onConnect` returned `{res:?}`"), + Err(err) => log::error!("Failed to call `onConnect`: {err}"), } Box::into_raw(session) diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 19ebac6d2..de9bd3cd0 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -4,7 +4,10 @@ version = "0.1.6" edition = "2021" [build-dependencies] +anyhow = "1.0.71" +diva = "0.1.0" swift-bridge-build = { git = "https://github.com/conectado/swift-bridge.git", branch = "fix-already-declared" } +walkdir = "2.3.3" [dependencies] libc = "0.2" diff --git a/rust/connlib/clients/apple/Sources/Connlib/Adapter.swift b/rust/connlib/clients/apple/Sources/Connlib/Adapter.swift index 4b8a36ccb..1c94aba17 100644 --- a/rust/connlib/clients/apple/Sources/Connlib/Adapter.swift +++ b/rust/connlib/clients/apple/Sources/Connlib/Adapter.swift @@ -92,13 +92,8 @@ public class Adapter { do { try self.setNetworkSettings(self.generateNetworkSettings(ipv4Routes: [], ipv6Routes: [])) - self.state = .started( - WrappedSession.connect( - "http://localhost:4568", - "test-token", - Self.callbackHandler! - ) + try WrappedSession.connect("http://localhost:4568", "test-token", Self.callbackHandler!) ) self.networkMonitor = networkMonitor completionHandler(nil) diff --git a/rust/connlib/clients/apple/Sources/Connlib/BridgingHeader-SwiftPM.h b/rust/connlib/clients/apple/Sources/Connlib/BridgingHeader-SwiftPM.h new file mode 100644 index 000000000..137094540 --- /dev/null +++ b/rust/connlib/clients/apple/Sources/Connlib/BridgingHeader-SwiftPM.h @@ -0,0 +1,17 @@ +// This header is used in `build.rs`, and is exactly the same as +// `BridgingHeader.h` *except* the `include` paths are relative. +// +// Attempting to build an `xcframework` with a quoted `include` violates rules +// around non-modular imports, as only headers specified as part of the module +// can be included. +// +// However, SwiftPM has no equivalent to "modular headers", so we can only rely +// on normal, simple `include` paths. + +#ifndef BridgingHeader_h +#define BridgingHeader_h + +#include "Generated/SwiftBridgeCore.h" +#include "Generated/connlib-apple/connlib-apple.h" + +#endif diff --git a/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift b/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift index 37c9cbec1..42b3360fd 100644 --- a/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift +++ b/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift @@ -8,6 +8,10 @@ import NetworkExtension import os.log +// TODO: https://github.com/chinedufn/swift-bridge/issues/150 +extension SwiftConnlibError: @unchecked Sendable {} +extension SwiftConnlibError: Error {} + public protocol CallbackHandlerDelegate: AnyObject { func didUpdateResources(_ resourceList: ResourceList) } @@ -45,7 +49,7 @@ public class CallbackHandler { ) } - func onSetTunnelAddresses(tunnelAddresses: TunnelAddresses) -> Bool { + func onConnect(tunnelAddresses: TunnelAddresses) -> Bool { let addresses4 = [tunnelAddresses.address4.toString()] let addresses6 = [tunnelAddresses.address6.toString()] let ipv4Routes = @@ -58,6 +62,10 @@ public class CallbackHandler { ) } + func onDisconnect() { + // TODO: handle disconnect + } + private func setTunnelSettingsKeepingSomeExisting( addresses4: [String], addresses6: [String], ipv4Routes: [NEIPv4Route], ipv6Routes: [NEIPv6Route] ) -> Bool { @@ -90,4 +98,10 @@ public class CallbackHandler { return false } } + + func onError(error: SwiftConnlibError, error_type: SwiftErrorType) { + // TODO: handle/report errors + let logger = Logger(subsystem: "dev.firezone.firezone", category: "packet-tunnel") + logger.log(level: .error, "Internal connlib error: \(String(describing: error), privacy: .public)") + } } diff --git a/rust/connlib/clients/apple/build-rust.sh b/rust/connlib/clients/apple/build-rust.sh index f97b3ce13..791e261b3 100755 --- a/rust/connlib/clients/apple/build-rust.sh +++ b/rust/connlib/clients/apple/build-rust.sh @@ -24,7 +24,9 @@ base_dir=$(xcrun --sdk $PLATFORM_NAME --show-sdk-path) # See https://github.com/briansmith/ring/issues/1332 export LIBRARY_PATH="${base_dir}/usr/lib" export INCLUDE_PATH="${base_dir}/usr/include" -export CFLAGS="-L ${LIBRARY_PATH} -I ${INCLUDE_PATH}" +# `-Qunused-arguments` stops clang from failing while building *ring* +# (but the library search path is still necessary when building the framework!) +export CFLAGS="-L ${LIBRARY_PATH} -I ${INCLUDE_PATH} -Qunused-arguments" export RUSTFLAGS="-C link-arg=-F$base_dir/System/Library/Frameworks" TARGETS="" diff --git a/rust/connlib/clients/apple/build-xcframework-dev.sh b/rust/connlib/clients/apple/build-xcframework-dev.sh new file mode 100755 index 000000000..324b12322 --- /dev/null +++ b/rust/connlib/clients/apple/build-xcframework-dev.sh @@ -0,0 +1,29 @@ +# For more info: +# https://github.com/firezone/firezone-apple/blob/main/USING_UNRELEASED_CONNLIB.md + +#!/bin/bash +set -ex + +echo $SRC_ROOT + +for sdk in macosx iphoneos; do + echo "Building for $sdk" + + xcodebuild archive \ + -scheme Connlib \ + -destination "generic/platform=$sdk" \ + -sdk $sdk \ + -archivePath ./connlib-$sdk \ + SKIP_INSTALL=NO \ + BUILD_LIBRARY_FOR_DISTRIBUTION=YES +done + +rm -rf ./Connlib.xcframework +xcodebuild -create-xcframework \ + -framework ./connlib-iphoneos.xcarchive/Products/Library/Frameworks/connlib.framework \ + -framework ./connlib-macosx.xcarchive/Products/Library/Frameworks/connlib.framework \ + -output ./Connlib.xcframework + +echo "Build successful. Removing temporary archives" +rm -rf ./connlib-iphoneos.xcarchive +rm -rf ./connlib-macosx.xcarchive diff --git a/rust/connlib/clients/apple/build.rs b/rust/connlib/clients/apple/build.rs index ddc9a9ccb..2989be079 100644 --- a/rust/connlib/clients/apple/build.rs +++ b/rust/connlib/clients/apple/build.rs @@ -1,14 +1,140 @@ -const XCODE_CONFIGURATION_ENV: &str = "CONFIGURATION"; +// Referenced from https://github.com/chinedufn/swift-bridge/blob/master/examples/rust-binary-calls-swift-package/build.rs -fn main() { - let out_dir = "Sources/Connlib/Generated"; +use std::path::PathBuf; +use walkdir::WalkDir; - let bridges = vec!["src/lib.rs"]; - for path in &bridges { - println!("cargo:rerun-if-changed={}", path); - } - println!("cargo:rerun-if-env-changed={}", XCODE_CONFIGURATION_ENV); +static XCODE_CONFIGURATION_ENV: &str = "CONFIGURATION"; +static SWIFT_PKG_NAME: &str = "Connlib"; +static SWIFT_LIB_NAME: &str = "libConnlib.a"; +static BRIDGE_SRCS: &[&str] = &["src/lib.rs"]; +static BRIDGING_HEADER: &str = "BridgingHeader-SwiftPM.h"; +static MACOSX_DEPLOYMENT_TARGET: &str = "12.4"; +static IPHONEOS_DEPLOYMENT_TARGET: &str = "15.6"; - swift_bridge_build::parse_bridges(bridges) - .write_all_concatenated(out_dir, env!("CARGO_PKG_NAME")); +mod sdk { + pub static MACOS: &str = "macosx"; + pub static IOS: &str = "iphoneos"; + pub static IOS_SIM: &str = "iphonesimulator"; +} + +struct Env { + swift_pkg_dir: PathBuf, + swift_src_dir: PathBuf, + bridge_dst_dir: PathBuf, + swift_built_lib_dir: PathBuf, + release: bool, + triple: String, + sdk: &'static str, +} + +impl Env { + fn gather() -> Self { + let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + let swift_pkg_dir = manifest_dir; + let swift_src_dir = swift_pkg_dir.join("Sources").join(SWIFT_PKG_NAME); + let bridge_dst_dir = swift_src_dir.join("Generated"); + let release = std::env::var("PROFILE").unwrap() == "release"; + let target = std::env::var("TARGET").unwrap(); + let (triple, sdk) = match target.as_str() { + "aarch64-apple-darwin" => ( + format!("arm64-apple-macosx{MACOSX_DEPLOYMENT_TARGET}"), + sdk::MACOS, + ), + "x86_64-apple-darwin" => ( + format!("x86_64-apple-macosx{MACOSX_DEPLOYMENT_TARGET}"), + sdk::MACOS, + ), + "aarch64-apple-ios" => ( + format!("arm64-apple-ios{IPHONEOS_DEPLOYMENT_TARGET}"), + sdk::IOS, + ), + "aarch64-apple-ios-sim" => ( + format!("arm64-apple-ios{IPHONEOS_DEPLOYMENT_TARGET}-simulator"), + sdk::IOS_SIM, + ), + "x86_64-apple-ios" | "x86_64-apple-ios-sim" => ( + format!("x86_64-apple-ios{IPHONEOS_DEPLOYMENT_TARGET}-simulator"), + sdk::IOS_SIM, + ), + _ => todo!("unsupported target triple: {target:?}"), + }; + let swift_built_lib_dir = swift_pkg_dir.join(".build").join(&triple).join(if release { + "release" + } else { + "debug" + }); + Self { + swift_pkg_dir, + swift_src_dir, + bridge_dst_dir, + swift_built_lib_dir, + release, + triple, + sdk, + } + } +} + +fn gen_bridges(env: &Env) { + for path in BRIDGE_SRCS { + println!("cargo:rerun-if-changed={path}"); + } + swift_bridge_build::parse_bridges(BRIDGE_SRCS) + .write_all_concatenated(&env.bridge_dst_dir, env!("CARGO_PKG_NAME")); +} + +// We use `swiftc` instead of SwiftPM/`swift build` because of this limitation: +// https://github.com/apple/swift-package-manager/pull/6572 +fn compile_swift(env: &Env) -> anyhow::Result<()> { + let swift_sdk = diva::Command::parse("xcrun --show-sdk-path --sdk") + .with_arg(env.sdk) + .run_and_wait_for_trimmed()?; + let swift_src_files = WalkDir::new(&env.swift_src_dir) + .into_iter() + .filter_map(Result::ok) + .filter_map(|entry| { + (entry.path().extension() == Some("swift".as_ref())).then(|| entry.path().to_owned()) + }); + std::fs::create_dir_all(&env.swift_built_lib_dir)?; + diva::Command::parse("swiftc -emit-library -static") + .with_args(["-module-name", SWIFT_PKG_NAME]) + .with_arg("-import-objc-header") + .with_arg(env.swift_src_dir.join(BRIDGING_HEADER)) + .with_arg("-sdk") + .with_arg(swift_sdk) + .with_args(["-target", &env.triple]) + // https://github.com/apple/swift-package-manager/blob/55006dce81ae70cd8f2b78479038423eeebde1e4/Documentation/Usage.md#setting-the-build-configuration + .with_parsed_args(if !env.release { + "-Onone -g -enable-testing" + } else { + "-O -whole-module-optimization" + }) + .with_arg("-o") + .with_arg(env.swift_built_lib_dir.join(SWIFT_LIB_NAME)) + .with_args(swift_src_files) + .with_cwd(&env.swift_pkg_dir) + .run_and_wait()?; + Ok(()) +} + +fn link_swift(env: &Env) { + println!("cargo:rustc-link-lib=static={SWIFT_PKG_NAME}"); + println!( + "cargo:rustc-link-search={}", + env.swift_built_lib_dir.display() + ); + let xcode_path = diva::Command::parse("xcode-select --print-path") + .run_and_wait_for_trimmed() + .unwrap_or_else(|_| "/Applications/Xcode.app/Contents/Developer".to_owned()); + println!("cargo:rustc-link-search={xcode_path}/Toolchains/XcodeDefault.xctoolchain/usr/lib/swift/macosx/"); + println!("cargo:rustc-link-search=/usr/lib/swift"); +} + +fn main() -> anyhow::Result<()> { + println!("cargo:rerun-if-env-changed={XCODE_CONFIGURATION_ENV}"); + let env = Env::gather(); + gen_bridges(&env); + compile_swift(&env)?; + link_swift(&env); + Ok(()) } diff --git a/rust/connlib/clients/apple/connlib.xcodeproj/project.pbxproj b/rust/connlib/clients/apple/connlib.xcodeproj/project.pbxproj index cfa75143c..9af2e33ab 100644 --- a/rust/connlib/clients/apple/connlib.xcodeproj/project.pbxproj +++ b/rust/connlib/clients/apple/connlib.xcodeproj/project.pbxproj @@ -7,6 +7,7 @@ objects = { /* Begin PBXBuildFile section */ + 1DAA1D872A3142BC00D84E07 /* BridgingHeader.h in Headers */ = {isa = PBXBuildFile; fileRef = 1DAA1D862A3142BC00D84E07 /* BridgingHeader.h */; settings = {ATTRIBUTES = (Public, ); }; }; 8D46EDDF29DBC29800FF01CA /* Adapter.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8D46EDD729DBC29800FF01CA /* Adapter.swift */; }; 8D46EDE029DBC29800FF01CA /* CallbackHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8D46EDD829DBC29800FF01CA /* CallbackHandler.swift */; }; 8D967B2B29DBA064000B9D58 /* libconnlib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 8D967B2A29DBA03F000B9D58 /* libconnlib.a */; }; @@ -16,18 +17,17 @@ 8DA207FC29DBD80C00703A4A /* SwiftBridgeCore.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DA207F729DBD80C00703A4A /* SwiftBridgeCore.swift */; }; 8DA207FD29DBD86100703A4A /* SwiftBridgeCore.h in Headers */ = {isa = PBXBuildFile; fileRef = 8DA207F629DBD80C00703A4A /* SwiftBridgeCore.h */; settings = {ATTRIBUTES = (Public, ); }; }; 8DA207FE29DBD86100703A4A /* connlib.h in Headers */ = {isa = PBXBuildFile; fileRef = 8D4BADD129DBD6CC00940F0D /* connlib.h */; settings = {ATTRIBUTES = (Public, ); }; }; - 8DA207FF29DBD86100703A4A /* BridgingHeader.h in Headers */ = {isa = PBXBuildFile; fileRef = 8D46EDD629DBC29800FF01CA /* BridgingHeader.h */; settings = {ATTRIBUTES = (Public, ); }; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 1DAA1D862A3142BC00D84E07 /* BridgingHeader.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BridgingHeader.h; sourceTree = ""; }; 8D209DCE29DBE96B00B68D27 /* Security.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Security.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS16.4.sdk/System/Library/Frameworks/Security.framework; sourceTree = DEVELOPER_DIR; }; - 8D46EDD629DBC29800FF01CA /* BridgingHeader.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BridgingHeader.h; sourceTree = ""; }; 8D46EDD729DBC29800FF01CA /* Adapter.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Adapter.swift; sourceTree = ""; }; 8D46EDD829DBC29800FF01CA /* CallbackHandler.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = CallbackHandler.swift; sourceTree = ""; }; 8D4BADD129DBD6CC00940F0D /* connlib.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = connlib.h; sourceTree = ""; }; 8D7D983129DB8437007B8198 /* connlib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = connlib.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 8D967B2629DB9A3B000B9D58 /* build-rust.sh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.script.sh; path = "build-rust.sh"; sourceTree = ""; }; - 8D967B2A29DBA03F000B9D58 /* libconnlib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libconnlib.a; path = target/universal/debug/libconnlib.a; sourceTree = ""; }; + 8D967B2A29DBA03F000B9D58 /* libconnlib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = libconnlib.a; path = ../../target/universal/debug/libconnlib.a; sourceTree = ""; }; 8DA207F329DBD80C00703A4A /* connlib-apple.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "connlib-apple.swift"; sourceTree = ""; }; 8DA207F429DBD80C00703A4A /* connlib-apple.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "connlib-apple.h"; sourceTree = ""; }; 8DA207F529DBD80C00703A4A /* .gitignore */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = .gitignore; sourceTree = ""; }; @@ -52,7 +52,7 @@ children = ( 8DA207F129DBD80C00703A4A /* Generated */, 8D4BADD129DBD6CC00940F0D /* connlib.h */, - 8D46EDD629DBC29800FF01CA /* BridgingHeader.h */, + 1DAA1D862A3142BC00D84E07 /* BridgingHeader.h */, 8D46EDD729DBC29800FF01CA /* Adapter.swift */, 8D46EDD829DBC29800FF01CA /* CallbackHandler.swift */, ); @@ -130,9 +130,9 @@ buildActionMask = 2147483647; files = ( 8DA207F929DBD80C00703A4A /* connlib-apple.h in Headers */, + 1DAA1D872A3142BC00D84E07 /* BridgingHeader.h in Headers */, 8DA207FD29DBD86100703A4A /* SwiftBridgeCore.h in Headers */, 8DA207FE29DBD86100703A4A /* connlib.h in Headers */, - 8DA207FF29DBD86100703A4A /* BridgingHeader.h in Headers */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -377,7 +377,7 @@ "@executable_path/../Frameworks", "@loader_path/Frameworks", ); - LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/target/universal/debug"; + LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../../../target/universal/debug"; MACOSX_DEPLOYMENT_TARGET = 12.4; MARKETING_VERSION = 1.0; MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; @@ -420,8 +420,8 @@ "@loader_path/Frameworks", ); LIBRARY_SEARCH_PATHS = ( - "$(PROJECT_DIR)/target/universal/release", - "$(PROJECT_DIR)/target/universal/debug", + "$(PROJECT_DIR)/../../../target/universal/release", + "$(PROJECT_DIR)/../../../target/universal/debug", ); MACOSX_DEPLOYMENT_TARGET = 12.4; MARKETING_VERSION = 1.0; diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index c539b79ef..b2f721406 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -1,11 +1,11 @@ -// Swift bridge generated code triggers this below -#![allow(improper_ctypes)] #![cfg(any(target_os = "macos", target_os = "ios"))] +// Swift bridge generated code triggers this below +#![allow(improper_ctypes, non_camel_case_types)] use firezone_client_connlib::{ - Callbacks, Error, ErrorType, ResourceList, Session, SwiftConnlibError, SwiftErrorType, - TunnelAddresses, + Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, }; +use std::sync::Arc; #[swift_bridge::bridge] mod ffi { @@ -14,24 +14,52 @@ mod ffi { resources: String, } - // TODO: Allegedly not FFI safe, but works #[swift_bridge(swift_repr = "struct")] struct TunnelAddresses { address4: String, address6: String, } - #[swift_bridge(already_declared)] - enum SwiftConnlibError {} + // TODO: Duplicating these enum variants from `libs/common/src/error.rs` is + // brittle/noisy/tedious + enum SwiftConnlibError { + Io, + Base64DecodeError, + Base64DecodeSliceError, + RequestError, + PortalConnectionError, + UriError, + SerializeError, + IceError, + IceDataError, + SendChannelError, + ConnectionEstablishError, + WireguardError, + NoRuntime, + UnknownResource, + ControlProtocolError, + IfaceRead, + Other, + InvalidTunnelName, + NetlinkErrorIo, + NoIface, + NoMtu, + } - #[swift_bridge(already_declared)] - enum SwiftErrorType {} + enum SwiftErrorType { + Recoverable, + Fatal, + } extern "Rust" { type WrappedSession; #[swift_bridge(associated_to = WrappedSession)] - fn connect(portal_url: String, token: String) -> Result; + fn connect( + portal_url: String, + token: String, + callback_handler: CallbackHandler, + ) -> Result; #[swift_bridge(swift_name = "bumpSockets")] fn bump_sockets(&self) -> bool; @@ -43,15 +71,62 @@ mod ffi { } extern "Swift" { - type Opaque; - #[swift_bridge(swift_name = "onUpdateResources")] - fn on_update_resources(resourceList: ResourceList); + type CallbackHandler; - #[swift_bridge(swift_name = "onSetTunnelAddresses")] - fn on_set_tunnel_addresses(tunnelAddresses: TunnelAddresses); + #[swift_bridge(swift_name = "onUpdateResources")] + fn on_update_resources(&self, resourceList: ResourceList); + + #[swift_bridge(swift_name = "onConnect")] + fn on_connect(&self, tunnelAddresses: TunnelAddresses); + + #[swift_bridge(swift_name = "onDisconnect")] + fn on_disconnect(&self); #[swift_bridge(swift_name = "onError")] - fn on_error(error: SwiftConnlibError, error_type: SwiftErrorType); + fn on_error(&self, error: SwiftConnlibError, error_type: SwiftErrorType); + } +} + +impl<'a> From<&'a Error> for ffi::SwiftConnlibError { + fn from(val: &'a Error) -> Self { + match val { + Error::Io(..) => Self::Io, + Error::Base64DecodeError(..) => Self::Base64DecodeError, + Error::Base64DecodeSliceError(..) => Self::Base64DecodeSliceError, + Error::RequestError(..) => Self::RequestError, + Error::PortalConnectionError(..) => Self::PortalConnectionError, + Error::UriError => Self::UriError, + Error::SerializeError(..) => Self::SerializeError, + Error::IceError(..) => Self::IceError, + Error::IceDataError(..) => Self::IceDataError, + Error::SendChannelError => Self::SendChannelError, + Error::ConnectionEstablishError => Self::ConnectionEstablishError, + Error::WireguardError(..) => Self::WireguardError, + Error::NoRuntime => Self::NoRuntime, + Error::UnknownResource => Self::UnknownResource, + Error::ControlProtocolError => Self::ControlProtocolError, + Error::IfaceRead(..) => Self::IfaceRead, + Error::Other(..) => Self::Other, + Error::InvalidTunnelName => Self::InvalidTunnelName, + Error::NetlinkErrorIo(_) => Self::NetlinkErrorIo, + Error::NoIface => Self::NoIface, + Error::NoMtu => Self::NoMtu, + } + } +} + +impl From for ffi::SwiftConnlibError { + fn from(val: Error) -> Self { + (&val).into() + } +} + +impl From for ffi::SwiftErrorType { + fn from(val: ErrorType) -> Self { + match val { + ErrorType::Recoverable => Self::Recoverable, + ErrorType::Fatal => Self::Fatal, + } } } @@ -77,26 +152,49 @@ pub struct WrappedSession { session: Session, } -struct CallbackHandler; +// SAFETY: `CallbackHandler.swift` promises to be thread-safe. +// TODO: Uphold that promise! +unsafe impl Send for ffi::CallbackHandler {} +unsafe impl Sync for ffi::CallbackHandler {} + +#[derive(Clone)] +#[repr(transparent)] +// Generated Swift opaque type wrappers have a `Drop` impl that decrements the +// refcount, but there's no way to generate a `Clone` impl that increments the +// recount. Instead, we just wrap it in an `Arc`. +pub struct CallbackHandler(Arc); impl Callbacks for CallbackHandler { - fn on_update_resources(resource_list: ResourceList) { - ffi::on_update_resources(resource_list.into()); + fn on_update_resources(&self, resource_list: ResourceList) { + self.0.on_update_resources(resource_list.into()) } - fn on_set_tunnel_adresses(tunnel_addresses: TunnelAddresses) { - ffi::on_set_tunnel_addresses(tunnel_addresses.into()); + fn on_connect(&self, tunnel_addresses: TunnelAddresses) { + self.0.on_connect(tunnel_addresses.into()) } - fn on_error(error: &Error, error_type: ErrorType) { - ffi::on_error(error.into(), error_type.into()); + fn on_disconnect(&self) { + self.0.on_disconnect() + } + + fn on_error(&self, error: &Error, error_type: ErrorType) { + self.0.on_error(error.into(), error_type.into()) } } impl WrappedSession { - fn connect(portal_url: String, token: String) -> Result { - let session = Session::connect::(portal_url.as_str(), token)?; - Ok(Self { session }) + fn connect( + portal_url: String, + token: String, + callback_handler: ffi::CallbackHandler, + ) -> Result { + Ok(Self { + session: Session::connect( + portal_url.as_str(), + token, + CallbackHandler(callback_handler.into()), + )?, + }) } fn bump_sockets(&self) -> bool { diff --git a/rust/connlib/clients/headless/src/main.rs b/rust/connlib/clients/headless/src/main.rs index fb8c3c3f2..7711804a3 100644 --- a/rust/connlib/clients/headless/src/main.rs +++ b/rust/connlib/clients/headless/src/main.rs @@ -7,18 +7,23 @@ use firezone_client_connlib::{ }; use url::Url; -enum CallbackHandler {} +#[derive(Clone)] +pub struct CallbackHandler; impl Callbacks for CallbackHandler { - fn on_update_resources(_resource_list: ResourceList) { + fn on_update_resources(&self, _resource_list: ResourceList) { todo!() } - fn on_set_tunnel_adresses(_tunnel_addresses: TunnelAddresses) { + fn on_connect(&self, _tunnel_addresses: TunnelAddresses) { todo!() } - fn on_error(error: &Error, error_type: ErrorType) { + fn on_disconnect(&self) { + todo!() + } + + fn on_error(&self, error: &Error, error_type: ErrorType) { match error_type { ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"), ErrorType::Fatal => panic!("Encountered fatal error: {error}"), @@ -40,8 +45,7 @@ fn main() -> Result<()> { // TODO: allow passing as arg vars let url = parse_env_var::(URL_ENV_VAR)?; let secret = parse_env_var::(SECRET_ENV_VAR)?; - // TODO: This is disgusting - let mut session = Session::::connect::(url, secret).unwrap(); + let mut session = Session::connect(url, secret, CallbackHandler).unwrap(); tracing::info!("Started new session"); session.wait_for_ctrl_c().unwrap(); session.disconnect(); diff --git a/rust/connlib/gateway/src/main.rs b/rust/connlib/gateway/src/main.rs index d94f5e4eb..40e45dc9c 100644 --- a/rust/connlib/gateway/src/main.rs +++ b/rust/connlib/gateway/src/main.rs @@ -6,18 +6,23 @@ use firezone_gateway_connlib::{ }; use url::Url; -enum CallbackHandler {} +#[derive(Clone)] +pub struct CallbackHandler; impl Callbacks for CallbackHandler { - fn on_update_resources(_resource_list: ResourceList) { + fn on_update_resources(&self, _resource_list: ResourceList) { todo!() } - fn on_set_tunnel_adresses(_tunnel_addresses: TunnelAddresses) { + fn on_connect(&self, _tunnel_addresses: TunnelAddresses) { todo!() } - fn on_error(error: &Error, error_type: ErrorType) { + fn on_disconnect(&self) { + todo!() + } + + fn on_error(&self, error: &Error, error_type: ErrorType) { match error_type { ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"), ErrorType::Fatal => panic!("Encountered fatal error: {error}"), @@ -33,8 +38,7 @@ fn main() -> Result<()> { // TODO: allow passing as arg vars let url = parse_env_var::(URL_ENV_VAR)?; let secret = parse_env_var::(SECRET_ENV_VAR)?; - // TODO: This is disgusting - let mut session = Session::::connect::(url, secret).unwrap(); + let mut session = Session::connect(url, secret, CallbackHandler).unwrap(); session.wait_for_ctrl_c().unwrap(); session.disconnect(); Ok(()) diff --git a/rust/connlib/libs/client/src/control.rs b/rust/connlib/libs/client/src/control.rs index ee2901966..04edd0651 100644 --- a/rust/connlib/libs/client/src/control.rs +++ b/rust/connlib/libs/client/src/control.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use crate::messages::{Connect, EgressMessages, InitClient, Messages, Relays}; use boringtun::x25519::StaticSecret; @@ -27,10 +27,9 @@ impl ControlSignal for ControlSignaler { } /// Implementation of [ControlSession] for clients. -pub struct ControlPlane { - tunnel: Arc>, +pub struct ControlPlane { + tunnel: Arc>, control_signaler: ControlSignaler, - _phantom: PhantomData, } #[derive(Clone)] @@ -38,10 +37,7 @@ struct ControlSignaler { internal_sender: Arc>, } -impl ControlPlane -where - C: Send + Sync + 'static, -{ +impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] async fn start(mut self, mut receiver: Receiver) { let mut interval = tokio::time::interval(Duration::from_secs(10)); @@ -64,7 +60,7 @@ where ) { if let Err(e) = self.tunnel.set_interface(&interface).await { tracing::error!("Couldn't initialize interface: {e}"); - C::on_error(&e, Fatal); + self.tunnel.callbacks().on_error(&e, Fatal); return; } @@ -89,7 +85,7 @@ where .recieved_offer_response(resource_id, rtc_sdp, gateway_public_key.0.into()) .await { - C::on_error(&e, Recoverable); + self.tunnel.callbacks().on_error(&e, Recoverable); } } @@ -127,12 +123,12 @@ where .await { tunnel.cleanup_connection(resource_id); - C::on_error(&err.into(), Recoverable); + tunnel.callbacks().on_error(&err.into(), Recoverable); } } Err(err) => { tunnel.cleanup_connection(resource_id); - C::on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err, Recoverable); } } }); @@ -157,12 +153,11 @@ where } #[async_trait] -impl ControlSession - for ControlPlane -{ - #[tracing::instrument(level = "trace", skip(private_key))] +impl ControlSession for ControlPlane { + #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, + callbacks: CB, ) -> Result<(Sender, Receiver)> { // This is kinda hacky, the buffer size is 1 so that we make sure that we // process one message at a time, blocking if a previous message haven't been processed @@ -172,12 +167,11 @@ impl ControlSession { + let control_plane = ControlPlane { tunnel, control_signaler, - _phantom: PhantomData, }; tokio::spawn(async move { control_plane.start(receiver).await }); diff --git a/rust/connlib/libs/client/src/lib.rs b/rust/connlib/libs/client/src/lib.rs index f85827552..2cc10699c 100644 --- a/rust/connlib/libs/client/src/lib.rs +++ b/rust/connlib/libs/client/src/lib.rs @@ -9,13 +9,17 @@ mod messages; /// Session type for clients. /// /// For more information see libs_common docs on [Session][libs_common::Session]. -pub type Session = - libs_common::Session, IngressMessages, EgressMessages, ReplyMessages, Messages>; +pub type Session = libs_common::Session< + ControlPlane, + IngressMessages, + EgressMessages, + ReplyMessages, + Messages, + CB, +>; pub use libs_common::{ - error::SwiftConnlibError, - error_type::{ErrorType, SwiftErrorType}, - get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses, + error_type::ErrorType, get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses, }; use messages::Messages; use messages::ReplyMessages; diff --git a/rust/connlib/libs/common/Cargo.toml b/rust/connlib/libs/common/Cargo.toml index aad27970c..f72b58946 100644 --- a/rust/connlib/libs/common/Cargo.toml +++ b/rust/connlib/libs/common/Cargo.toml @@ -27,8 +27,6 @@ ip_network = { version = "0.4", default-features = false, features = ["serde"] } boringtun = { workspace = true } os_info = { version = "3", default-features = false } -macros = { path = "../../macros" } - [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] swift-bridge = { workspace = true } diff --git a/rust/connlib/libs/common/src/error.rs b/rust/connlib/libs/common/src/error.rs index 6daf38838..df7006a4e 100644 --- a/rust/connlib/libs/common/src/error.rs +++ b/rust/connlib/libs/common/src/error.rs @@ -1,14 +1,13 @@ //! Error module. use base64::{DecodeError, DecodeSliceError}; use boringtun::noise::errors::WireGuardError; -use macros::SwiftEnum; use thiserror::Error; /// Unified Result type to use across connlib. pub type Result = std::result::Result; /// Unified error type to use across connlib. -#[derive(Error, Debug, SwiftEnum)] +#[derive(Error, Debug)] pub enum ConnlibError { /// Standard IO error. #[error(transparent)] @@ -80,10 +79,6 @@ pub enum ConnlibError { NoMtu, } -/// Type auto-generated by [SwiftEnum] intended to be used with rust-swift-bridge. -/// All the variants come from [ConnlibError], reference that for documentation. -pub use swift_ffi::SwiftConnlibError; - #[cfg(target_os = "linux")] impl From for ConnlibError { fn from(err: rtnetlink::Error) -> Self { diff --git a/rust/connlib/libs/common/src/error_type.rs b/rust/connlib/libs/common/src/error_type.rs index 7f411c87f..a0e99638f 100644 --- a/rust/connlib/libs/common/src/error_type.rs +++ b/rust/connlib/libs/common/src/error_type.rs @@ -1,10 +1,10 @@ //! Module that contains the Error-Type that hints how to handle an error to upper layers. -use macros::SwiftEnum; + /// This indicates whether the produced error is something recoverable or fatal. /// Fata/Recoverable only indicates how to handle the error for the client. /// /// Any of the errors in [ConnlibError][crate::error::ConnlibError] could be of any [ErrorType] depending the circumstance. -#[derive(Debug, Clone, Copy, SwiftEnum)] +#[derive(Debug, Clone, Copy)] pub enum ErrorType { /// Recoverable means that the session can continue /// e.g. Failed to send an SDP @@ -14,7 +14,3 @@ pub enum ErrorType { /// e.g. Max number of retries was reached when trying to connect to the portal. Fatal, } - -/// Auto generated enum by [SwiftEnum], all variants come from [ErrorType] -/// reference that for docs. -pub use swift_ffi::SwiftErrorType; diff --git a/rust/connlib/libs/common/src/session.rs b/rust/connlib/libs/common/src/session.rs index f48e0397f..8a688a9ff 100644 --- a/rust/connlib/libs/common/src/session.rs +++ b/rust/connlib/libs/common/src/session.rs @@ -17,9 +17,9 @@ use crate::{control::PhoenixChannel, error_type::ErrorType, messages::Key, Error // TODO: Not the most tidy trait for a control-plane. /// Trait that represents a control-plane. #[async_trait] -pub trait ControlSession { +pub trait ControlSession { /// Start control-plane with the given private-key in the background. - async fn start(private_key: StaticSecret) -> Result<(Sender, Receiver)>; + async fn start(private_key: StaticSecret, callbacks: CB) -> Result<(Sender, Receiver)>; /// Either "gateway" or "client" used to get the control-plane URL. fn socket_path() -> &'static str; @@ -31,9 +31,9 @@ pub trait ControlSession { /// A session is the entry-point for connlib, maintains the runtime and the tunnel. /// /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. -pub struct Session { +pub struct Session { runtime: Option, - _phantom: PhantomData<(T, U, V, R, M)>, + _phantom: PhantomData<(T, U, V, R, M, CB)>, } /// Resource list that will be displayed to the users. @@ -51,38 +51,41 @@ pub struct TunnelAddresses { // Evaluate doing this not static /// Traits that will be used by connlib to callback the client upper layers. -pub trait Callbacks { +pub trait Callbacks: Clone + Send + Sync { /// Called when there's a change in the resource list. - fn on_update_resources(resource_list: ResourceList); + fn on_update_resources(&self, resource_list: ResourceList); /// Called when the tunnel address is set. - fn on_set_tunnel_adresses(tunnel_addresses: TunnelAddresses); + fn on_connect(&self, tunnel_addresses: TunnelAddresses); + /// Called when the tunnel is disconnected. + fn on_disconnect(&self); /// Called when there's an error. /// /// # Parameters /// - `error`: The actual error that happened. /// - `error_type`: Whether the error should terminate the session or not. - fn on_error(error: &Error, error_type: ErrorType); + fn on_error(&self, error: &Error, error_type: ErrorType); } macro_rules! fatal_error { - ($result:expr, $c:ty) => { + ($result:expr, $c:expr) => { match $result { Ok(res) => res, Err(e) => { - <$c>::on_error(&e, ErrorType::Fatal); + $c.on_error(&e, ErrorType::Fatal); return; } } }; } -impl Session +impl Session where - T: ControlSession, + T: ControlSession, U: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static, R: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static, V: serde::Serialize + Send + 'static, M: From + From + Send + 'static + std::fmt::Debug, + CB: Callbacks + 'static, { /// Block on waiting for ctrl+c to terminate the runtime. /// (Used for the gateways). @@ -103,11 +106,11 @@ where /// 2. Connect to the control plane to the portal /// 3. Start the tunnel in the background and forward control plane messages to it. /// - /// The generic parameter `C` should implement all the handlers and that's how errors will be surfaced. + /// The generic parameter `CB` should implement all the handlers and that's how errors will be surfaced. /// /// On a fatal error you should call `[Session::disconnect]` and start a new one. // TODO: token should be something like SecretString but we need to think about FFI compatibility - pub fn connect(portal_url: impl TryInto, token: String) -> Result { + pub fn connect(portal_url: impl TryInto, token: String, callbacks: CB) -> Result { // TODO: We could use tokio::runtime::current() to get the current runtime // which could work with swif-rust that already runs a runtime. But IDK if that will work // in all pltaforms, a couple of new threads shouldn't bother none. @@ -124,9 +127,9 @@ where let private_key = StaticSecret::random_from_rng(OsRng); let self_id = uuid::Uuid::new_v4(); - let connect_url = fatal_error!(get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string()), C); + let connect_url = fatal_error!(get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string()), callbacks); - let (sender, mut receiver) = fatal_error!(T::start(private_key).await, C); + let (sender, mut receiver) = fatal_error!(T::start(private_key, callbacks.clone()).await, callbacks); let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg| { let sender = sender.clone(); @@ -150,15 +153,15 @@ where if let Some(t) = exponential_backoff.next_backoff() { tracing::warn!("Error during connection to the portal, retrying in {} seconds", t.as_secs()); match result { - Ok(()) => C::on_error(&tokio_tungstenite::tungstenite::Error::ConnectionClosed.into(), ErrorType::Recoverable), - Err(e) => C::on_error(&e, ErrorType::Recoverable) + Ok(()) => callbacks.on_error(&tokio_tungstenite::tungstenite::Error::ConnectionClosed.into(), ErrorType::Recoverable), + Err(e) => callbacks.on_error(&e, ErrorType::Recoverable) } tokio::time::sleep(t).await; } else { tracing::error!("Connection to the portal error, check your internet or the status of the portal.\nDisconnecting interface."); match result { - Ok(()) => C::on_error(&crate::Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed), ErrorType::Fatal), - Err(e) => C::on_error(&e, ErrorType::Fatal) + Ok(()) => callbacks.on_error(&crate::Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed), ErrorType::Fatal), + Err(e) => callbacks.on_error(&e, ErrorType::Fatal) } break; } diff --git a/rust/connlib/libs/gateway/src/control.rs b/rust/connlib/libs/gateway/src/control.rs index 6537be258..c7875333e 100644 --- a/rust/connlib/libs/gateway/src/control.rs +++ b/rust/connlib/libs/gateway/src/control.rs @@ -17,8 +17,8 @@ use async_trait::async_trait; const INTERNAL_CHANNEL_SIZE: usize = 256; -pub struct ControlPlane { - tunnel: Arc>, +pub struct ControlPlane { + tunnel: Arc>, control_signaler: ControlSignaler, } @@ -35,10 +35,7 @@ impl ControlSignal for ControlSignaler { } } -impl ControlPlane -where - C: Send + Sync + 'static, -{ +impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] async fn start(mut self, mut receiver: Receiver) { let mut interval = tokio::time::interval(Duration::from_secs(10)); @@ -55,7 +52,7 @@ where async fn init(&mut self, init: InitGateway) { if let Err(e) = self.tunnel.set_interface(&init.interface).await { tracing::error!("Couldn't initialize interface: {e}"); - C::on_error(&e, Fatal); + self.tunnel.callbacks().on_error(&e, Fatal); return; } @@ -87,12 +84,12 @@ where .await { tunnel.cleanup_peer_connection(connection_request.device.id); - C::on_error(&err.into(), Recoverable); + tunnel.callbacks().on_error(&err.into(), Recoverable); } } Err(err) => { tunnel.cleanup_peer_connection(connection_request.device.id); - C::on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err, Recoverable); } } }); @@ -123,13 +120,13 @@ where } #[async_trait] -impl ControlSession for ControlPlane -where - C: Send + Sync + 'static, +impl ControlSession + for ControlPlane { - #[tracing::instrument(level = "trace", skip(private_key))] + #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, + callbacks: CB, ) -> Result<(Sender, Receiver)> { // This is kinda hacky, the buffer size is 1 so that we make sure that we // process one message at a time, blocking if a previous message haven't been processed @@ -140,7 +137,7 @@ where let (internal_sender, internal_receiver) = channel(INTERNAL_CHANNEL_SIZE); let internal_sender = Arc::new(internal_sender); let control_signaler = ControlSignaler { internal_sender }; - let tunnel = Arc::new(Tunnel::<_, C>::new(private_key, control_signaler.clone()).await?); + let tunnel = Arc::new(Tunnel::new(private_key, control_signaler.clone(), callbacks).await?); let control_plane = ControlPlane { tunnel, diff --git a/rust/connlib/libs/gateway/src/lib.rs b/rust/connlib/libs/gateway/src/lib.rs index 6fa63b7cc..41e3394f7 100644 --- a/rust/connlib/libs/gateway/src/lib.rs +++ b/rust/connlib/libs/gateway/src/lib.rs @@ -10,12 +10,13 @@ mod messages; /// /// For more information see libs_common docs on [Session][libs_common::Session]. // TODO: Still working on gateway messages -pub type Session = libs_common::Session< - ControlPlane, +pub type Session = libs_common::Session< + ControlPlane, IngressMessages, EgressMessages, IngressMessages, IngressMessages, + CB, >; pub use libs_common::{error_type::ErrorType, Callbacks, Error, ResourceList, TunnelAddresses}; diff --git a/rust/connlib/libs/tunnel/src/control_protocol.rs b/rust/connlib/libs/tunnel/src/control_protocol.rs index eedc05c63..7d4d006e2 100644 --- a/rust/connlib/libs/tunnel/src/control_protocol.rs +++ b/rust/connlib/libs/tunnel/src/control_protocol.rs @@ -21,10 +21,10 @@ use webrtc::{ use crate::{peer::Peer, ControlSignal, PeerConfig, Tunnel}; -impl Tunnel +impl Tunnel where - C: Send + Sync + 'static, - CB: Send + Sync + 'static, + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, { async fn handle_channel_open( self: &Arc, @@ -160,7 +160,7 @@ where let Some(gateway_public_key) = tunnel.gateway_public_keys.lock().remove(&resource_id) else { tunnel.cleanup_connection(resource_id); tracing::warn!("Opened ICE channel with gateway without ever receiving public key"); - CB::on_error(&Error::ControlProtocolError, Recoverable); + tunnel.callbacks.on_error(&Error::ControlProtocolError, Recoverable); return; }; let peer_config = PeerConfig { @@ -172,7 +172,7 @@ where if let Err(e) = tunnel.handle_channel_open(d, index, peer_config).await { tracing::error!("Couldn't establish wireguard link after channel was opened: {e}"); - CB::on_error(&e, Recoverable); + tunnel.callbacks.on_error(&e, Recoverable); tunnel.cleanup_connection(resource_id); } tunnel.awaiting_connection.lock().remove(&resource_id); @@ -279,7 +279,7 @@ where Box::pin(async move { if let Err(e) = tunnel.handle_channel_open(data_channel, index, peer).await { - CB::on_error(&e, Recoverable); + tunnel.callbacks.on_error(&e, Recoverable); tracing::error!( "Couldn't establish wireguard link after opening channel: {e}" ); diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index 9f91aa3fa..d7b6735f6 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -34,7 +34,6 @@ use webrtc::{ use std::{ collections::{HashMap, HashSet}, - marker::PhantomData, net::IpAddr, sync::Arc, time::Duration, @@ -145,21 +144,25 @@ pub struct Tunnel { resources: RwLock, control_signaler: C, gateway_public_keys: Mutex>, - _phantom: PhantomData, + callbacks: CB, } -impl Tunnel +impl Tunnel where - C: Send + Sync + 'static, - CB: Send + Sync + 'static, + C: ControlSignal + Send + Sync + 'static, + CB: Callbacks + 'static, { /// Creates a new tunnel. /// /// # Parameters /// - `private_key`: wireguard's private key. /// - `control_signaler`: this is used to send SDP from the tunnel to the control plane. - #[tracing::instrument(level = "trace", skip(private_key, control_signaler))] - pub async fn new(private_key: StaticSecret, control_signaler: C) -> Result { + #[tracing::instrument(level = "trace", skip(private_key, control_signaler, callbacks))] + pub async fn new( + private_key: StaticSecret, + control_signaler: C, + callbacks: CB, + ) -> Result { let public_key = (&private_key).into(); let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT)); let peers_by_ip = RwLock::new(IpNetworkTable::new()); @@ -203,7 +206,7 @@ where resources, awaiting_connection, control_signaler, - _phantom: PhantomData, + callbacks, }) } @@ -217,7 +220,7 @@ where let mut iface_config = self.iface_config.lock().await; for ip in resource_description.ips() { if let Err(err) = iface_config.add_route(ip).await { - CB::on_error(&err, Fatal); + self.callbacks.on_error(&err, Fatal); } } } @@ -247,7 +250,7 @@ where Ok(()) } - async fn peer_refresh(peer: &Peer, dst_buf: &mut [u8; MAX_UDP_SIZE]) { + async fn peer_refresh(&self, peer: &Peer, dst_buf: &mut [u8; MAX_UDP_SIZE]) { let update_timers_result = peer.update_timers(&mut dst_buf[..]); match update_timers_result { @@ -256,7 +259,9 @@ where tracing::error!("Connection expired"); } TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), - TunnResult::WriteToNetwork(packet) => peer.send_infallible::(packet).await, + TunnResult::WriteToNetwork(packet) => { + peer.send_infallible(packet, &self.callbacks).await + } _ => panic!("Unexpected result from update_timers"), }; } @@ -292,7 +297,7 @@ where .collect(); for peer in peers { - Self::peer_refresh(&peer, &mut dst_buf).await; + tunnel.peer_refresh(&peer, &mut dst_buf).await; } interval.tick().await; @@ -336,7 +341,7 @@ where ) { Ok(packet) => packet, Err(TunnResult::WriteToNetwork(cookie)) => { - peer.send_infallible::(cookie).await; + peer.send_infallible(cookie, &tunnel.callbacks).await; continue; } Err(_) => continue, @@ -360,7 +365,7 @@ where TunnResult::Err(_) => continue, TunnResult::WriteToNetwork(packet) => { flush = true; - peer.send_infallible::(packet).await; + peer.send_infallible(packet, &tunnel.callbacks).await; } TunnResult::WriteToTunnelV4(packet, addr) => { if peer.is_allowed(addr) { @@ -380,7 +385,7 @@ where let res = peer.tunnel.lock().decapsulate(None, &[], &mut dst_buf[..]); res } { - peer.send_infallible::(packet).await; + peer.send_infallible(packet, &tunnel.callbacks).await; } } } @@ -389,13 +394,13 @@ where async fn write4_device_infallible(&self, packet: &[u8]) { if let Err(e) = self.device_channel.write4(packet).await { - CB::on_error(&e.into(), Recoverable); + self.callbacks.on_error(&e.into(), Recoverable); } } async fn write6_device_infallible(&self, packet: &[u8]) { if let Err(e) = self.device_channel.write6(packet).await { - CB::on_error(&e.into(), Recoverable); + self.callbacks.on_error(&e.into(), Recoverable); } } @@ -425,13 +430,13 @@ where Ok(res) => res, Err(err) => { tracing::error!("Couldn't read packet from interface: {err}"); - CB::on_error(&err.into(), Recoverable); + dev.callbacks.on_error(&err.into(), Recoverable); continue; } }, Err(err) => { tracing::error!("Couldn't obtain iface mtu: {err}"); - CB::on_error(&err, Recoverable); + dev.callbacks.on_error(&err, Recoverable); continue; } } @@ -472,7 +477,7 @@ where // Not a deadlock because this is a different task dev.awaiting_connection.lock().remove(&id); tracing::error!("couldn't start protocol for new connection to resource: {e}"); - CB::on_error(&e, Recoverable); + dev.callbacks.on_error(&e, Recoverable); } }); } @@ -490,13 +495,13 @@ where } TunnResult::Err(e) => { tracing::error!(message = "Encapsulate error for resource corresponding to {dst_addr}", error = ?e); - CB::on_error(&e.into(), Recoverable); + dev.callbacks.on_error(&e.into(), Recoverable); } TunnResult::WriteToNetwork(packet) => { tracing::trace!("writing iface packet to peer: {dst_addr}"); if let Err(e) = channel.write(&Bytes::copy_from_slice(packet)).await { tracing::error!("Couldn't write packet to channel: {e}"); - CB::on_error(&e.into(), Recoverable); + dev.callbacks.on_error(&e.into(), Recoverable); } } _ => panic!("Unexpected result from encapsulate"), @@ -508,4 +513,8 @@ where fn next_index(&self) -> u32 { self.next_index.lock().next() } + + pub fn callbacks(&self) -> &CB { + &self.callbacks + } } diff --git a/rust/connlib/libs/tunnel/src/peer.rs b/rust/connlib/libs/tunnel/src/peer.rs index ac38ec64a..740119d5f 100644 --- a/rust/connlib/libs/tunnel/src/peer.rs +++ b/rust/connlib/libs/tunnel/src/peer.rs @@ -21,10 +21,10 @@ pub(crate) struct Peer { } impl Peer { - pub(crate) async fn send_infallible(&self, data: &[u8]) { + pub(crate) async fn send_infallible(&self, data: &[u8], callbacks: &CB) { if let Err(e) = self.channel.write(&Bytes::copy_from_slice(data)).await { tracing::error!("Couldn't send packet to connected peer: {e}"); - CB::on_error(&e.into(), ErrorType::Recoverable); + callbacks.on_error(&e.into(), ErrorType::Recoverable); } } diff --git a/rust/connlib/libs/tunnel/src/tun_android.rs b/rust/connlib/libs/tunnel/src/tun_android.rs index d47d94a1d..29c244a34 100644 --- a/rust/connlib/libs/tunnel/src/tun_android.rs +++ b/rust/connlib/libs/tunnel/src/tun_android.rs @@ -1,20 +1,87 @@ use super::InterfaceConfig; -use libs_common::Result; +use ip_network::IpNetwork; +use libc::{close, open, O_RDWR}; +use libs_common::{Error, Result}; +use std::{ + os::fd::{AsRawFd, RawFd}, + sync::Arc, +}; #[derive(Debug)] pub(crate) struct IfaceConfig(pub(crate) Arc); #[derive(Debug)] -pub(crate) struct IfaceDevice; +pub(crate) struct IfaceDevice { + fd: RawFd, +} -impl IfaceConfig { - // It's easier to not make these functions async, setting these should not block the thread for too long - #[tracing::instrument(level = "trace", skip(self))] - pub fn set_iface_config(&mut self, _config: &InterfaceConfig) -> Result<()> { - todo!() - } - - pub fn up(&mut self) -> Result<()> { - todo!() +impl AsRawFd for IfaceDevice { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for IfaceDevice { + fn drop(&mut self) { + unsafe { close(self.fd) }; + } +} + +impl IfaceDevice { + fn write(&self, _buf: &[u8]) -> usize { + tracing::error!("`write` unimplemented on Android"); + 0 + } + + pub async fn new(_name: &str) -> Result { + // TODO: This won't actually work for non-root users... + let fd = unsafe { open(b"/dev/net/tun\0".as_ptr() as _, O_RDWR) }; + // TODO: everything! + if fd == -1 { + Err(Error::Io(std::io::Error::last_os_error())) + } else { + Ok(Self { fd }) + } + } + + pub fn set_non_blocking(self) -> Result { + tracing::error!("`set_non_blocking` unimplemented on Android"); + Ok(self) + } + + pub async fn mtu(&self) -> Result { + tracing::error!("`mtu` unimplemented on Android"); + Ok(0) + } + + pub fn write4(&self, src: &[u8]) -> usize { + self.write(src) + } + + pub fn write6(&self, src: &[u8]) -> usize { + self.write(src) + } + + pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8]> { + tracing::error!("`read` unimplemented on Android"); + Ok(dst) + } +} + +impl IfaceConfig { + pub async fn add_route(&mut self, route: IpNetwork) -> Result<()> { + tracing::error!("`add_route` unimplemented on Android: `{route:#?}`"); + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + pub async fn set_iface_config(&mut self, _config: &InterfaceConfig) -> Result<()> { + tracing::error!("`set_iface_config` unimplemented on Android: `{_config:#?}`"); + Ok(()) + } + + pub async fn up(&mut self) -> Result<()> { + tracing::error!("`up` unimplemented on Android"); + Ok(()) } } diff --git a/rust/connlib/libs/tunnel/src/tun_darwin.rs b/rust/connlib/libs/tunnel/src/tun_darwin.rs index 10936e8d7..62ae6740d 100644 --- a/rust/connlib/libs/tunnel/src/tun_darwin.rs +++ b/rust/connlib/libs/tunnel/src/tun_darwin.rs @@ -262,21 +262,21 @@ impl IfaceDevice { // So, these functions take a mutable &self, this is not necessary in theory but it's correct! impl IfaceConfig { + pub async fn add_route(&mut self, route: IpNetwork) -> Result<()> { + tracing::error!("`add_route` unimplemented on macOS: `{route:#?}`"); + Ok(()) + } + #[tracing::instrument(level = "trace", skip(self))] pub async fn set_iface_config(&mut self, config: &InterfaceConfig) -> Result<()> { - // TODO - + tracing::error!("`set_iface_config` unimplemented on macOS: `{config:#?}`"); Ok(()) } pub async fn up(&mut self) -> Result<()> { - // TODO + tracing::error!("`up` unimplemented on macOS"); Ok(()) } - - pub async fn add_route(&mut self, route: IpNetwork) -> Result<()> { - todo!() - } } fn get_last_error() -> Error { diff --git a/rust/connlib/libs/tunnel/src/tun_win.rs b/rust/connlib/libs/tunnel/src/tun_win.rs index 9edce1c12..c7a458f3a 100644 --- a/rust/connlib/libs/tunnel/src/tun_win.rs +++ b/rust/connlib/libs/tunnel/src/tun_win.rs @@ -16,7 +16,7 @@ impl IfaceConfig { todo!() } - pub async fn add_route(&mut self, route: IpNetwork) -> Result<()> { + pub async fn add_route(&mut self, _route: IpNetwork) -> Result<()> { todo!() } } diff --git a/rust/connlib/macros/Cargo.toml b/rust/connlib/macros/Cargo.toml deleted file mode 100644 index 5326335be..000000000 --- a/rust/connlib/macros/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "macros" -version = "0.1.0" -edition = "2021" - -[lib] -proc-macro = true - -[dependencies] -syn = { version = "2.0" } -proc-macro2 = { version = "1.0" } -quote = { version = "1.0" } diff --git a/rust/connlib/macros/src/lib.rs b/rust/connlib/macros/src/lib.rs deleted file mode 100644 index 15044226b..000000000 --- a/rust/connlib/macros/src/lib.rs +++ /dev/null @@ -1,108 +0,0 @@ -#![recursion_limit = "128"] - -extern crate proc_macro; -use proc_macro2::{Span, TokenStream}; -use quote::quote; -use syn::{Data, DeriveInput, Fields}; - -/// Macro that generates a new enum with only the discriminants of another enum within a module that implements swift_bridge. -/// -/// This is a workaround to create an error type compatible with swift that can be converted from the original error type. -/// it implements `From` so the idea is that you can call a swift ffi function `handle_error(err.into());` -/// -/// This makes a lot of assumption about the types it's being implemented on since we're controlling the type it is not meant -/// to be a public macro. (However be careful if you reuse it somewhere else! this is based in strum's EnumDiscrminant so you can -/// check there for an actual proper implementation). -/// -/// IMPORTANT!: You need to include swift_bridge::bridge for macos and ios target so this doesn't error out. -#[proc_macro_derive(SwiftEnum)] -pub fn swift_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ast = syn::parse_macro_input!(input as DeriveInput); - - let toks = swift_enum_inner(&ast).unwrap_or_else(|err| err.to_compile_error()); - toks.into() -} - -fn swift_enum_inner(ast: &DeriveInput) -> syn::Result { - let name = &ast.ident; - let vis = &ast.vis; - - let variants = match &ast.data { - Data::Enum(v) => &v.variants, - _ => { - return Err(syn::Error::new( - Span::call_site(), - "This macro only support enums.", - )) - } - }; - - let discriminants: Vec<_> = variants - .into_iter() - .map(|v| { - let ident = &v.ident; - quote! {#ident} - }) - .collect(); - - let enum_name = syn::Ident::new(&format!("Swift{}", name), Span::call_site()); - let mod_name = syn::Ident::new("swift_ffi", Span::call_site()); - - let arms = variants - .iter() - .map(|variant| { - let ident = &variant.ident; - let params = match &variant.fields { - Fields::Unit => quote! {}, - Fields::Unnamed(_fields) => { - quote! { (..) } - } - Fields::Named(_fields) => { - quote! { { .. } } - } - }; - - quote! { #name::#ident #params => #mod_name::#enum_name::#ident } - }) - .collect::>(); - - let from_fn_body = quote! { match val { #(#arms),* } }; - - let impl_from_ref = { - quote! { - impl<'a> ::core::convert::From<&'a #name> for #mod_name::#enum_name { - fn from(val: &'a #name) -> Self { - #from_fn_body - } - } - } - }; - - let impl_from = { - quote! { - impl ::core::convert::From<#name> for #mod_name::#enum_name { - fn from(val: #name) -> Self { - #from_fn_body - } - } - } - }; - - // If we wanted to expose this function we should have another crate that actually also includes - // swift_bridge. but since we are only using this inside our crates we can just make sure we include it. - Ok(quote! { - #[cfg_attr(any(target_os = "macos", target_os = "ios"), swift_bridge::bridge)] - #vis mod #mod_name { - pub enum #enum_name { - #(#discriminants),* - } - - } - - #[cfg(any(target_os = "macos", target_os = "ios"))] - #impl_from_ref - - #[cfg(any(target_os = "macos", target_os = "ios"))] - #impl_from - }) -}