diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs index 84a4f833f36..42ac43d0df5 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs @@ -1,20 +1,7 @@ use serde_json; -use std::{ - alloc::{alloc, Layout}, - ptr, - sync::Arc, - time::Duration, -}; -use windows::core::{s, HRESULT}; +use std::{sync::Arc, time::Duration}; -use crate::util::{delay_load, wstr_to_string}; -use crate::webauthn::WEBAUTHN_CREDENTIAL_LIST; -use crate::{ - com_provider::{ - parse_credential_list, WebAuthnPluginOperationRequest, WebAuthnPluginOperationResponse, - }, - ipc2::PasskeyAssertionWithoutUserInterfaceRequest, -}; +use crate::ipc2::PasskeyAssertionWithoutUserInterfaceRequest; use crate::{ ipc2::{ PasskeyAssertionRequest, PasskeyAssertionResponse, Position, TimedCallback, @@ -23,105 +10,69 @@ use crate::{ win_webauthn::{ErrorKind, HwndExt, PluginGetAssertionRequest, WinWebAuthnError}, }; -// Windows API types for WebAuthn (from webauthn.h.sample) -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST { - pub dwVersion: u32, - pub pwszRpId: *const u16, // PCWSTR - pub cbRpId: u32, - pub pbRpId: *const u8, - pub cbClientDataHash: u32, - pub pbClientDataHash: *const u8, - pub CredentialList: WEBAUTHN_CREDENTIAL_LIST, - pub cbCborExtensionsMap: u32, - pub pbCborExtensionsMap: *const u8, - pub pAuthenticatorOptions: *const crate::webauthn::WebAuthnCtapCborAuthenticatorOptions, - // Add other fields as needed... -} +pub fn get_assertion( + ipc_client: &WindowsProviderClient, + request: PluginGetAssertionRequest, +) -> Result, Box> { + // Extract RP information + let rp_id = request.rp_id().to_string(); -pub type PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST = *mut WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST; + // Extract client data hash + let client_data_hash = request.client_data_hash().to_vec(); -// Windows API function signatures for decoding get assertion requests -type WebAuthNDecodeGetAssertionRequestFn = unsafe extern "stdcall" fn( - cbEncoded: u32, - pbEncoded: *const u8, - ppGetAssertionRequest: *mut PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST, -) -> HRESULT; + // Extract user verification requirement from authenticator options + let user_verification = match request.authenticator_options().user_verification() { + Some(true) => UserVerification::Required, + Some(false) => UserVerification::Discouraged, + None => UserVerification::Preferred, + }; -type WebAuthNFreeDecodedGetAssertionRequestFn = - unsafe extern "stdcall" fn(pGetAssertionRequest: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST); + // Extract allowed credentials from credential list + let allowed_credential_ids: Vec> = request + .credential_list() + .iter() + .filter_map(|cred| cred.credential_id()) + .map(|id| id.to_vec()) + .collect(); -// RAII wrapper for decoded get assertion request -pub struct DecodedGetAssertionRequest { - ptr: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST, - free_fn: Option, -} + let transaction_id = request.transaction_id.to_u128().to_le_bytes().to_vec(); + let client_pos = request + .window_handle + .center_position() + .unwrap_or((640, 480)); -impl DecodedGetAssertionRequest { - fn new( - ptr: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST, - free_fn: Option, - ) -> Self { - Self { ptr, free_fn } - } - - pub fn as_ref(&self) -> &WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST { - unsafe { &*self.ptr } - } -} - -impl Drop for DecodedGetAssertionRequest { - fn drop(&mut self) { - if !self.ptr.is_null() { - if let Some(free_fn) = self.free_fn { - tracing::debug!("Freeing decoded get assertion request"); - unsafe { - free_fn(self.ptr); - } - } - } - } -} - -// Function to decode get assertion request using Windows API -unsafe fn decode_get_assertion_request( - encoded_request: &[u8], -) -> Result { - tracing::debug!("Attempting to decode get assertion request using Windows API"); - - // Load the Windows WebAuthn API function - let decode_fn: Option = - delay_load(s!("webauthn.dll"), s!("WebAuthNDecodeGetAssertionRequest")); - - let decode_fn = - decode_fn.ok_or("Failed to load WebAuthNDecodeGetAssertionRequest from webauthn.dll")?; - - // Load the free function - let free_fn: Option = delay_load( - s!("webauthn.dll"), - s!("WebAuthNFreeDecodedGetAssertionRequest"), + tracing::debug!( + "Get assertion request - RP: {}, Allowed credentials: {:?}", + rp_id, + allowed_credential_ids ); - let mut pp_get_assertion_request: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST = ptr::null_mut(); + // Send assertion request + let assertion_request = PasskeyAssertionRequest { + rp_id, + client_data_hash, + allowed_credentials: allowed_credential_ids, + user_verification, + window_xy: Position { + x: client_pos.0, + y: client_pos.1, + }, + context: transaction_id, + }; + let passkey_response = send_assertion_request(ipc_client, assertion_request) + .map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?; + tracing::debug!("Assertion response received: {:?}", passkey_response); - let result = decode_fn( - encoded_request.len() as u32, - encoded_request.as_ptr(), - &mut pp_get_assertion_request, - ); + // Create proper WebAuthn response from passkey_response + tracing::debug!("Creating WebAuthn get assertion response"); - if result.is_err() || pp_get_assertion_request.is_null() { - return Err(format!( - "WebAuthNDecodeGetAssertionRequest failed with HRESULT: {}", - result.0 - )); - } - - Ok(DecodedGetAssertionRequest::new( - pp_get_assertion_request, - free_fn, - )) + let response = create_get_assertion_response( + passkey_response.credential_id, + passkey_response.authenticator_data, + passkey_response.signature, + passkey_response.user_handle, + )?; + Ok(response) } /// Helper for assertion requests @@ -237,241 +188,9 @@ fn create_get_assertion_response( Ok(cbor_data) } -unsafe fn write_response( - cbor_data: &[u8], -) -> Result<*mut WebAuthnPluginOperationResponse, HRESULT> { - // Allocate memory for the response data - let response_len = cbor_data.len(); - let layout = Layout::from_size_align(response_len, 1).map_err(|_| HRESULT(-1))?; - let response_ptr = alloc(layout); - if response_ptr.is_null() { - return Err(HRESULT(-1)); - } - - // Copy response data - ptr::copy_nonoverlapping(cbor_data.as_ptr(), response_ptr, response_len); - - // Allocate memory for the response structure - let response_layout = Layout::new::(); - let operation_response_ptr = alloc(response_layout) as *mut WebAuthnPluginOperationResponse; - if operation_response_ptr.is_null() { - return Err(HRESULT(-1)); - } - - // Initialize the response - ptr::write( - operation_response_ptr, - WebAuthnPluginOperationResponse { - encoded_response_byte_count: response_len as u32, - encoded_response_pointer: response_ptr, - }, - ); - - Ok(operation_response_ptr) -} - -pub fn get_assertion( - ipc_client: &WindowsProviderClient, - request: PluginGetAssertionRequest, -) -> Result, Box> { - // Extract RP information - let rp_id = request.rp_id().to_string(); - - // Extract client data hash - let client_data_hash = request.client_data_hash().to_vec(); - - // Extract user verification requirement from authenticator options - let user_verification = match request.authenticator_options().user_verification() { - Some(true) => UserVerification::Required, - Some(false) => UserVerification::Discouraged, - None => UserVerification::Preferred, - }; - - // Extract allowed credentials from credential list - let allowed_credential_ids: Vec> = request - .credential_list() - .iter() - .filter_map(|cred| cred.credential_id()) - .map(|id| id.to_vec()) - .collect(); - - let transaction_id = request.transaction_id.to_u128().to_le_bytes().to_vec(); - let client_pos = request - .window_handle - .center_position() - .unwrap_or((640, 480)); - - tracing::debug!( - "Get assertion request - RP: {}, Allowed credentials: {:?}", - rp_id, - allowed_credential_ids - ); - - // Send assertion request - let assertion_request = PasskeyAssertionRequest { - rp_id, - client_data_hash, - allowed_credentials: allowed_credential_ids, - user_verification, - window_xy: Position { - x: client_pos.0, - y: client_pos.1, - }, - context: transaction_id, - }; - let passkey_response = send_assertion_request(ipc_client, assertion_request) - .map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?; - tracing::debug!("Assertion response received: {:?}", passkey_response); - - // Create proper WebAuthn response from passkey_response - tracing::debug!("Creating WebAuthn get assertion response"); - - let response = create_get_assertion_response( - passkey_response.credential_id, - passkey_response.authenticator_data, - passkey_response.signature, - passkey_response.user_handle, - )?; - Ok(response) -} - -/// Implementation of PluginGetAssertion moved from com_provider.rs -pub unsafe fn plugin_get_assertion( - ipc_client: &WindowsProviderClient, - request: *const WebAuthnPluginOperationRequest, - response: *mut WebAuthnPluginOperationResponse, -) -> Result<(), HRESULT> { - tracing::debug!("PluginGetAssertion() called"); - - // Validate input parameters - if request.is_null() || response.is_null() { - tracing::debug!("Invalid parameters passed to PluginGetAssertion"); - return Err(HRESULT(-1)); - } - - let req = &*request; - let transaction_id = format!("{:?}", req.transaction_id); - let coords = req.window_coordinates().unwrap_or((400, 400)); - - tracing::debug!("Get assertion request - Transaction: {}", transaction_id); - - if req.encoded_request_byte_count == 0 || req.encoded_request_pointer.is_null() { - tracing::error!("No encoded request data provided"); - return Err(HRESULT(-1)); - } - - let encoded_request_slice = std::slice::from_raw_parts( - req.encoded_request_pointer, - req.encoded_request_byte_count as usize, - ); - - // Try to decode the request using Windows API - let decoded_wrapper = decode_get_assertion_request(encoded_request_slice).map_err(|err| { - tracing::debug!("Failed to decode get assertion request: {err}"); - HRESULT(-1) - })?; - let decoded_request = decoded_wrapper.as_ref(); - tracing::debug!("Successfully decoded get assertion request using Windows API"); - - // Extract RP information - let rpid = if decoded_request.pwszRpId.is_null() { - tracing::error!("RP ID is null"); - return Err(HRESULT(-1)); - } else { - match wstr_to_string(decoded_request.pwszRpId) { - Ok(id) => id, - Err(e) => { - tracing::error!("Failed to decode RP ID: {}", e); - return Err(HRESULT(-1)); - } - } - }; - - // Extract client data hash - let client_data_hash = - if decoded_request.cbClientDataHash == 0 || decoded_request.pbClientDataHash.is_null() { - tracing::error!("Client data hash is required for assertion"); - return Err(HRESULT(-1)); - } else { - let hash_slice = std::slice::from_raw_parts( - decoded_request.pbClientDataHash, - decoded_request.cbClientDataHash as usize, - ); - hash_slice.to_vec() - }; - - // Extract user verification requirement from authenticator options - let user_verification = if !decoded_request.pAuthenticatorOptions.is_null() { - let auth_options = &*decoded_request.pAuthenticatorOptions; - match auth_options.user_verification { - 1 => UserVerification::Required, - -1 => UserVerification::Discouraged, - 0 | _ => UserVerification::Preferred, // Default or undefined - } - } else { - UserVerification::Preferred // Default or undefined - }; - - // Extract allowed credentials from credential list - let allowed_credentials = parse_credential_list(&decoded_request.CredentialList); - - // Create Windows assertion request - let transaction_id = req.transaction_id.to_u128().to_le_bytes().to_vec(); - let assertion_request = PasskeyAssertionRequest { - rp_id: rpid.clone(), - client_data_hash, - allowed_credentials: allowed_credentials.clone(), - user_verification, - window_xy: Position { - x: coords.0, - y: coords.1, - }, - context: transaction_id, - }; - - tracing::debug!( - "Get assertion request - RP: {}, Allowed credentials: {:?}", - rpid, - allowed_credentials - ); - - // Send assertion request - let passkey_response = - send_assertion_request(ipc_client, assertion_request).map_err(|err| { - tracing::error!("Assertion request failed: {err}"); - HRESULT(-1) - })?; - tracing::debug!("Assertion response received: {:?}", passkey_response); - - // Create proper WebAuthn response from passkey_response - tracing::debug!("Creating WebAuthn get assertion response"); - - let webauthn_response = create_get_assertion_response( - passkey_response.credential_id, - passkey_response.authenticator_data, - passkey_response.signature, - passkey_response.user_handle, - ) - .map_err(|err| { - tracing::error!("Failed to encode WebAuthn assertion response as CBOR: {err}"); - HRESULT(-1) - }) - .and_then(|cbor| write_response(&cbor)) - .map_err(|err| { - tracing::error!("Failed to create WebAuthn assertion response: {err}"); - HRESULT(-1) - })?; - tracing::debug!("Successfully created WebAuthn assertion response"); - (*response).encoded_response_byte_count = (*webauthn_response).encoded_response_byte_count; - (*response).encoded_response_pointer = (*webauthn_response).encoded_response_pointer; - Ok(()) -} - #[cfg(test)] mod tests { - use std::ptr::slice_from_raw_parts; - - use super::{create_get_assertion_response, write_response}; + use super::create_get_assertion_response; #[test] fn test_create_native_assertion_response() { @@ -479,21 +198,14 @@ mod tests { let authenticator_data = vec![5, 6, 7, 8]; let signature = vec![9, 10, 11, 12]; let user_handle = vec![13, 14, 15, 16]; - let slice = unsafe { - let cbor = create_get_assertion_response( - credential_id, - authenticator_data, - signature, - user_handle, - ) - .unwrap(); - let response = *write_response(&cbor).unwrap(); - &*slice_from_raw_parts( - response.encoded_response_pointer, - response.encoded_response_byte_count as usize, - ) - }; + let cbor = create_get_assertion_response( + credential_id, + authenticator_data, + signature, + user_handle, + ) + .unwrap(); // CTAP2_OK, Map(5 elements) - assert_eq!([0x00, 0xa5], slice[..2]); + assert_eq!([0x00, 0xa5], cbor[..2]); } } diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/com_provider.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/com_provider.rs index 6ca9a844ee4..9c3e3acb955 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/com_provider.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/com_provider.rs @@ -1,14 +1,6 @@ -use std::sync::Arc; -use std::time::Duration; - -use windows::core::{implement, interface, IInspectable, IUnknown, Interface, HRESULT}; -use windows::Win32::Foundation::{RECT, S_OK}; -use windows::Win32::System::Com::*; +use windows::Win32::Foundation::RECT; use windows::Win32::UI::WindowsAndMessaging::GetWindowRect; -use crate::assert::plugin_get_assertion; -use crate::ipc2::{TimedCallback, WindowsProviderClient}; -use crate::make_credential::plugin_make_credential; use crate::webauthn::WEBAUTHN_CREDENTIAL_LIST; /// Plugin request type enum as defined in the IDL @@ -18,6 +10,7 @@ pub enum WebAuthnPluginRequestType { CTAP2_CBOR = 0x01, } +/* /// Plugin lock status enum as defined in the IDL #[repr(u32)] #[derive(Debug, Copy, Clone)] @@ -25,6 +18,7 @@ pub enum PluginLockStatus { PluginLocked = 0, PluginUnlocked = 1, } +*/ /// Used when creating and asserting credentials. /// Header File Name: _WEBAUTHN_PLUGIN_OPERATION_REQUEST @@ -65,34 +59,6 @@ pub struct WebAuthnPluginOperationResponse { pub encoded_response_pointer: *mut u8, } -/// Used to cancel an operation. -/// Header File Name: _WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST -/// Header File Usage: CancelOperation() -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct WebAuthnPluginCancelOperationRequest { - pub transaction_id: windows::core::GUID, - pub request_signature_byte_count: u32, - pub request_signature_pointer: *mut u8, -} - -// Stable IPluginAuthenticator interface -#[interface("d26bcf6f-b54c-43ff-9f06-d5bf148625f7")] -pub unsafe trait IPluginAuthenticator: windows::core::IUnknown { - fn MakeCredential( - &self, - request: *const WebAuthnPluginOperationRequest, - response: *mut WebAuthnPluginOperationResponse, - ) -> HRESULT; - fn GetAssertion( - &self, - request: *const WebAuthnPluginOperationRequest, - response: *mut WebAuthnPluginOperationResponse, - ) -> HRESULT; - fn CancelOperation(&self, request: *const WebAuthnPluginCancelOperationRequest) -> HRESULT; - fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT; -} - pub unsafe fn parse_credential_list(credential_list: &WEBAUTHN_CREDENTIAL_LIST) -> Vec> { let mut allowed_credentials = Vec::new(); @@ -144,112 +110,3 @@ pub unsafe fn parse_credential_list(credential_list: &WEBAUTHN_CREDENTIAL_LIST) ); allowed_credentials } - -#[implement(IPluginAuthenticator)] -pub struct PluginAuthenticatorComObject { - client: WindowsProviderClient, -} - -#[implement(IClassFactory)] -pub struct Factory; - -impl IPluginAuthenticator_Impl for PluginAuthenticatorComObject_Impl { - unsafe fn MakeCredential( - &self, - request: *const WebAuthnPluginOperationRequest, - response: *mut WebAuthnPluginOperationResponse, - ) -> HRESULT { - tracing::debug!("MakeCredential() called"); - tracing::debug!("version2"); - // Convert to legacy format for internal processing - if request.is_null() || response.is_null() { - tracing::debug!("MakeCredential: Invalid request or response pointers passed"); - return HRESULT(-1); - } - - let response = match plugin_make_credential(&self.client, request, response) { - Ok(()) => S_OK, - Err(err) => err, - }; - tracing::debug!("MakeCredential() completed"); - response - } - - unsafe fn GetAssertion( - &self, - request: *const WebAuthnPluginOperationRequest, - response: *mut WebAuthnPluginOperationResponse, - ) -> HRESULT { - tracing::debug!("GetAssertion() called"); - if request.is_null() || response.is_null() { - return HRESULT(-1); - } - - match plugin_get_assertion(&self.client, request, response) { - Ok(()) => S_OK, - Err(err) => err, - } - } - - unsafe fn CancelOperation( - &self, - _request: *const WebAuthnPluginCancelOperationRequest, - ) -> HRESULT { - tracing::debug!("CancelOperation() called"); - HRESULT(0) - } - - unsafe fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT { - tracing::debug!( - "GetLockStatus() called ", - std::process::id(), - std::thread::current().id() - ); - if lock_status.is_null() { - return HRESULT(-2147024809); // E_INVALIDARG - } - *lock_status = PluginLockStatus::PluginLocked; - let callback = Arc::new(TimedCallback::new()); - self.client.get_lock_status(callback.clone()); - match callback.wait_for_response(Duration::from_secs(3)) { - Ok(Ok(response)) => { - let status = if response.is_unlocked { - PluginLockStatus::PluginUnlocked - } else { - PluginLockStatus::PluginLocked - }; - tracing::debug!("GetLockStatus() received {status:?}"); - *lock_status = status; - HRESULT(0) - } - Ok(Err(err)) => { - tracing::error!("GetLockStatus() call failed: {err}"); - HRESULT(-1) - } - Err(_) => { - tracing::error!("GetLockStatus() call timed out"); - HRESULT(-1) - } - } - } -} - -impl IClassFactory_Impl for Factory_Impl { - fn CreateInstance( - &self, - _outer: windows::core::Ref, - iid: *const windows::core::GUID, - object: *mut *mut core::ffi::c_void, - ) -> windows::core::Result<()> { - tracing::debug!("Creating COM server instance."); - tracing::debug!("Trying to connect to Bitwarden IPC"); - let client = WindowsProviderClient::connect(); - tracing::debug!("Connected to Bitwarden IPC"); - let unknown: IInspectable = PluginAuthenticatorComObject { client }.into(); // TODO: IUnknown ? - unsafe { unknown.query(iid, object).ok() } - } - - fn LockServer(&self, _lock: windows::core::BOOL) -> windows::core::Result<()> { - Ok(()) - } -} diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/win_webauthn/types.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/win_webauthn/types.rs index 97f3398ae6c..ee2de74dd13 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/win_webauthn/types.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/win_webauthn/types.rs @@ -1,6 +1,6 @@ //! Types and functions defined in the Windows WebAuthn API. -use std::{collections::HashSet, ptr::NonNull}; +use std::{collections::HashSet, mem::MaybeUninit, ptr::NonNull}; use base64::{engine::general_purpose::STANDARD, Engine as _}; use ciborium::Value; @@ -10,7 +10,9 @@ use windows::{ }; use windows_core::{s, PCWSTR}; -use crate::win_webauthn::{util::WindowsString, Clsid, ErrorKind, WinWebAuthnError}; +use crate::win_webauthn::{ + com::ComBuffer, util::WindowsString, Clsid, ErrorKind, WinWebAuthnError, +}; macro_rules! webauthn_call { ($symbol:literal as fn $fn_name:ident($($arg:ident: $arg_type:ty),+) -> $result_type:ty) => (