diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs index c94bb984c2d..035d9df06cf 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs @@ -1,6 +1,3 @@ -use std::ffi::OsString; -use std::os::windows::ffi::OsStrExt; - use std::fs::{create_dir_all, OpenOptions}; use std::io::Write; use std::path::Path; @@ -10,6 +7,8 @@ use windows::Win32::Foundation::*; use windows::Win32::System::LibraryLoader::*; use windows_core::*; +use crate::com_buffer::ComBuffer; + pub unsafe fn delay_load(library: PCSTR, function: PCSTR) -> Option { let library = LoadLibraryExA(library, None, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS); @@ -28,32 +27,27 @@ pub unsafe fn delay_load(library: PCSTR, function: PCSTR) -> Option { None } +/// Trait for converting strings to Windows-compatible wide strings using COM allocation pub trait WindowsString { - fn into_win_utf8(self: Self) -> (*mut u8, u32); - fn into_win_utf16(self: Self) -> (*mut u16, u32); - fn into_win_utf16_wide(self: Self) -> (*mut u16, u32); + /// Converts to null-terminated UTF-16 using COM allocation + fn to_com_utf16(&self) -> (*mut u16, u32); + /// Converts to Vec for temporary use (caller must keep Vec alive) + fn to_utf16(&self) -> Vec; } -impl WindowsString for String { - fn into_win_utf8(self: Self) -> (*mut u8, u32) { - let mut v = self.into_bytes(); - v.push(0); - - (v.as_mut_ptr(), v.len() as u32) +impl WindowsString for str { + fn to_com_utf16(&self) -> (*mut u16, u32) { + let mut wide_vec: Vec = self.encode_utf16().collect(); + wide_vec.push(0); // null terminator + let wide_bytes: Vec = wide_vec.iter().flat_map(|&x| x.to_le_bytes()).collect(); + let (ptr, byte_count) = ComBuffer::from_buffer(&wide_bytes); + (ptr as *mut u16, byte_count) } - fn into_win_utf16(self: Self) -> (*mut u16, u32) { - let mut v: Vec = self.encode_utf16().collect(); - v.push(0); - - (v.as_mut_ptr(), v.len() as u32) - } - - fn into_win_utf16_wide(self: Self) -> (*mut u16, u32) { - let mut v: Vec = OsString::from(self).encode_wide().collect(); - v.push(0); - - (v.as_mut_ptr(), v.len() as u32) + fn to_utf16(&self) -> Vec { + let mut wide_vec: Vec = self.encode_utf16().collect(); + wide_vec.push(0); // null terminator + wide_vec } } diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/webauthn.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/webauthn.rs index 0c2214d1a79..5fd8fcc7832 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/webauthn.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/webauthn.rs @@ -7,7 +7,7 @@ use windows_core::*; -use crate::util::{debug_log, delay_load}; +use crate::util::{debug_log, delay_load, WindowsString}; use crate::com_buffer::ComBuffer; /// Windows WebAuthn Authenticator Options structure @@ -72,9 +72,6 @@ impl ExperimentalWebAuthnPluginCredentialDetails { user_name: String, user_display_name: String, ) -> Self { - use std::ffi::OsString; - use std::os::windows::ffi::OsStrExt; - // Convert credential_id bytes to hex string, then allocate with COM let credential_id_string = hex::encode(&credential_id); let (credential_id_pointer, credential_id_byte_count) = ComBuffer::from_buffer(credential_id_string.as_bytes()); @@ -83,36 +80,21 @@ impl ExperimentalWebAuthnPluginCredentialDetails { let user_id_string = hex::encode(&user_id); let (user_id_pointer, user_id_byte_count) = ComBuffer::from_buffer(user_id_string.as_bytes()); - // Convert strings to null-terminated wide strings and allocate with COM - let mut rpid_vec: Vec = OsString::from(rpid).encode_wide().collect(); - rpid_vec.push(0); - let rpid_bytes: Vec = rpid_vec.iter().flat_map(|&x| x.to_le_bytes()).collect(); - let (rpid_ptr, _) = ComBuffer::from_buffer(rpid_bytes); - - let mut rp_friendly_name_vec: Vec = OsString::from(rp_friendly_name).encode_wide().collect(); - rp_friendly_name_vec.push(0); - let rp_friendly_name_bytes: Vec = rp_friendly_name_vec.iter().flat_map(|&x| x.to_le_bytes()).collect(); - let (rp_friendly_name_ptr, _) = ComBuffer::from_buffer(rp_friendly_name_bytes); - - let mut user_name_vec: Vec = OsString::from(user_name).encode_wide().collect(); - user_name_vec.push(0); - let user_name_bytes: Vec = user_name_vec.iter().flat_map(|&x| x.to_le_bytes()).collect(); - let (user_name_ptr, _) = ComBuffer::from_buffer(user_name_bytes); - - let mut user_display_name_vec: Vec = OsString::from(user_display_name).encode_wide().collect(); - user_display_name_vec.push(0); - let user_display_name_bytes: Vec = user_display_name_vec.iter().flat_map(|&x| x.to_le_bytes()).collect(); - let (user_display_name_ptr, _) = ComBuffer::from_buffer(user_display_name_bytes); + // Convert strings to null-terminated wide strings using trait methods + let (rpid_ptr, _) = rpid.to_com_utf16(); + let (rp_friendly_name_ptr, _) = rp_friendly_name.to_com_utf16(); + let (user_name_ptr, _) = user_name.to_com_utf16(); + let (user_display_name_ptr, _) = user_display_name.to_com_utf16(); Self { credential_id_byte_count, credential_id_pointer, - rpid: rpid_ptr as *mut u16, - rp_friendly_name: rp_friendly_name_ptr as *mut u16, + rpid: rpid_ptr, + rp_friendly_name: rp_friendly_name_ptr, user_id_byte_count, user_id_pointer, - user_name: user_name_ptr as *mut u16, - user_display_name: user_display_name_ptr as *mut u16, + user_name: user_name_ptr, + user_display_name: user_display_name_ptr, } } } @@ -136,7 +118,7 @@ impl ExperimentalWebAuthnPluginCredentialDetailsList { credentials: Vec, ) -> Self { // Convert credentials to COM-allocated pointers - let mut credential_pointers: Vec<*mut ExperimentalWebAuthnPluginCredentialDetails> = credentials + let credential_pointers: Vec<*mut ExperimentalWebAuthnPluginCredentialDetails> = credentials .into_iter() .map(|cred| { // Use COM allocation for each credential struct @@ -160,14 +142,11 @@ impl ExperimentalWebAuthnPluginCredentialDetailsList { std::ptr::null_mut() }; - // Convert CLSID to wide string and allocate with COM - let mut clsid_wide: Vec = clsid.encode_utf16().collect(); - clsid_wide.push(0); // null terminator - let clsid_bytes: Vec = clsid_wide.iter().flat_map(|&x| x.to_le_bytes()).collect(); - let (clsid_ptr, _) = ComBuffer::from_buffer(clsid_bytes); + // Convert CLSID to wide string using trait method + let (clsid_ptr, _) = clsid.to_com_utf16(); Self { - plugin_clsid: clsid_ptr as *mut u16, + plugin_clsid: clsid_ptr, credential_count: credentials_len as u32, credentials: credentials_pointer, } @@ -276,8 +255,7 @@ pub fn get_all_credentials( match result { Some(api) => { // Create the wide string and keep it alive during the API call - let mut clsid_wide: Vec = plugin_clsid.encode_utf16().collect(); - clsid_wide.push(0); // null terminator + let clsid_wide = plugin_clsid.to_utf16(); let mut credentials_list_ptr: *mut ExperimentalWebAuthnPluginCredentialDetailsList = std::ptr::null_mut(); let result = unsafe { api(clsid_wide.as_ptr(), &mut credentials_list_ptr) }; @@ -318,8 +296,7 @@ pub fn remove_all_credentials( Some(api) => { debug_log("Function loaded successfully, calling API..."); // Create the wide string and keep it alive during the API call - let mut clsid_wide: Vec = plugin_clsid.encode_utf16().collect(); - clsid_wide.push(0); // null terminator + let clsid_wide = plugin_clsid.to_utf16(); let result = unsafe { api(clsid_wide.as_ptr()) };