diff --git a/rust/headless-client/src/dns_control/windows.rs b/rust/headless-client/src/dns_control/windows.rs index f7066c569..9afcd8284 100644 --- a/rust/headless-client/src/dns_control/windows.rs +++ b/rust/headless-client/src/dns_control/windows.rs @@ -228,7 +228,9 @@ mod tests { fn dns_control() { let _guard = firezone_logging::test("debug"); - let rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); let mut tun_dev_manager = firezone_bin_shared::TunDeviceManager::new(1280, 1).unwrap(); // Note: num_threads (`1`) is unused on windows. let _tun = tun_dev_manager.make_tun().unwrap(); diff --git a/rust/headless-client/src/ipc_service/windows.rs b/rust/headless-client/src/ipc_service/windows.rs index cd2bbd32f..80c09ac9c 100644 --- a/rust/headless-client/src/ipc_service/windows.rs +++ b/rust/headless-client/src/ipc_service/windows.rs @@ -11,10 +11,16 @@ use std::{ time::Duration, }; use tokio::sync::mpsc; -use windows::Win32::{ - Foundation::{CloseHandle, HANDLE}, - Security::{GetTokenInformation, TokenElevation, TOKEN_ELEVATION, TOKEN_QUERY}, - System::Threading::{GetCurrentProcess, OpenProcessToken}, +use windows::{ + core::PWSTR, + Win32::{ + Foundation::{CloseHandle, HANDLE}, + Security::{ + GetTokenInformation, LookupAccountSidW, TokenElevation, TokenUser, SID_NAME_USE, + TOKEN_ELEVATION, TOKEN_QUERY, TOKEN_USER, + }, + System::Threading::{GetCurrentProcess, OpenProcessToken}, + }, }; use windows_service::{ service::{ @@ -30,8 +36,14 @@ const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; /// Returns true if the IPC service can run properly pub(crate) fn elevation_check() -> Result { - let token = ProcessToken::our_process()?; - let elevated = token.is_elevated()?; + let token = ProcessToken::our_process().context("Failed to get process token")?; + let elevated = token + .is_elevated() + .context("Failed to get elevation status")?; + let username = token.username().context("Failed to get username")?; + + tracing::debug!(%username, %elevated); + Ok(elevated) } @@ -74,6 +86,57 @@ impl ProcessToken { }?; Ok(elevation.TokenIsElevated == 1) } + + fn username(&self) -> Result { + // Normally, the pattern here is to call `GetTokenInformation` with a size of 0 and retrieve the necessary buffer length from the first error. + // This doesn't seem to work in this case so we just allocate a hopefully sufficiently large buffer ahead of time. + let token_user_sz = 1024; + let mut token_user = vec![0u8; token_user_sz as usize]; + let token_user = token_user.as_mut_ptr() as *mut TOKEN_USER; + + let mut return_sz = 0; + + // Fetch the actual user information. + // SAFETY: Docs say nothing about threads or safety + unsafe { + GetTokenInformation( + self.inner, + TokenUser, + Some(token_user as *mut c_void), + token_user_sz, + &mut return_sz, + ) + }?; + + let mut name = vec![0u16; 256]; + let mut domain = vec![0u16; 256]; + let mut name_size = name.len() as u32; + let mut domain_size = domain.len() as u32; + let mut sid_type = SID_NAME_USE::default(); + + // Convert account ID to human-friendly name. + + // SAFETY: We allocated the buffer. + let sid = unsafe { (*token_user).User.Sid }; + + // SAFETY: Docs say nothing about threads or safety + unsafe { + LookupAccountSidW( + None, + sid, + PWSTR::from_raw(name.as_mut_ptr()), + &mut name_size, + PWSTR::from_raw(domain.as_mut_ptr()), + &mut domain_size, + &mut sid_type, + ) + }?; + + let name = String::from_utf16_lossy(&name[..name_size as usize]); + let domain = String::from_utf16_lossy(&domain[..domain_size as usize]); + + Ok(format!("{name}\\{domain}")) + } } impl Drop for ProcessToken { @@ -289,3 +352,17 @@ async fn service_run_async( } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[expect(clippy::print_stdout, reason = "We want to see the output in the test")] + fn get_username_of_current_process() { + let process_token = ProcessToken::our_process().unwrap(); + let username = process_token.username().unwrap(); // If this doesn't panic, we are good. + + println!("Running as user: {username}") + } +}