diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs index 70d8fb16816..064d34bd719 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs @@ -1,460 +1,463 @@ -//! Functions for interacting with Windows COM. -#![allow(non_snake_case)] -#![allow(non_camel_case_types)] - -use std::{ - alloc, - mem::MaybeUninit, - ptr::{self, NonNull}, - sync::{Arc, OnceLock}, -}; - -use windows::{ - core::{implement, interface, ComObjectInterface, IUnknown, GUID, HRESULT}, - Win32::{ - Foundation::{E_FAIL, E_INVALIDARG, S_OK}, - System::Com::*, - }, -}; -use windows_core::{IInspectable, Interface}; - -use super::types::{ - PluginLockStatus, WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST, WEBAUTHN_PLUGIN_OPERATION_REQUEST, - WEBAUTHN_PLUGIN_OPERATION_RESPONSE, -}; - -use super::PluginAuthenticator; -use crate::{ - plugin::{crypto, PluginGetAssertionRequest, PluginMakeCredentialRequest}, - ErrorKind, WinWebAuthnError, -}; - -static HANDLER: OnceLock<(GUID, Arc)> = OnceLock::new(); - -#[implement(IClassFactory)] -pub struct Factory; - -impl IClassFactory_Impl for Factory_Impl { - fn CreateInstance( - &self, - _outer: windows::core::Ref, - iid: *const windows::core::GUID, - object: *mut *mut core::ffi::c_void, - ) -> windows::core::Result<()> { - let (clsid, handler) = match HANDLER.get() { - Some(state) => state, - None => { - tracing::error!("Cannot create COM class object instance because the handler is not initialized. register_server() must be called before starting the COM server."); - return Err(E_FAIL.into()); - } - }.clone(); - let unknown: IInspectable = PluginAuthenticatorComObject { clsid, handler }.into(); - unsafe { unknown.query(iid, object).ok() } - } - - fn LockServer(&self, _lock: windows::core::BOOL) -> windows::core::Result<()> { - // TODO: Implement lock server - Ok(()) - } -} - -// IPluginAuthenticator interface -#[interface("d26bcf6f-b54c-43ff-9f06-d5bf148625f7")] -pub unsafe trait IPluginAuthenticator: windows::core::IUnknown { - fn MakeCredential( - &self, - request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, - response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, - ) -> HRESULT; - fn GetAssertion( - &self, - request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, - response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, - ) -> HRESULT; - fn CancelOperation(&self, request: *const WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST) -> HRESULT; - fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT; -} - -#[implement(IPluginAuthenticator)] -struct PluginAuthenticatorComObject { - clsid: GUID, - handler: Arc, -} - -impl IPluginAuthenticator_Impl for PluginAuthenticatorComObject_Impl { - unsafe fn MakeCredential( - &self, - request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, - response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, - ) -> HRESULT { - tracing::debug!("MakeCredential called"); - let response = match NonNull::new(response) { - Some(p) => p, - None => { - tracing::warn!( - "MakeCredential called with null response pointer from Windows. Aborting request." - ); - return E_INVALIDARG; - } - }; - let op_request_ptr = match NonNull::new(request.cast_mut()) { - Some(p) => p, - None => { - tracing::warn!( - "MakeCredential called with null request pointer from Windows. Aborting request." - ); - return E_INVALIDARG; - } - }; - - if let Err(err) = verify_operation_request(op_request_ptr.as_ref(), &self.clsid) { - tracing::error!("Failed to verify request signature: {err}"); - return E_INVALIDARG; - } - - // SAFETY: we received the pointer from Windows, so we trust that the values are set properly. - let registration_request = match PluginMakeCredentialRequest::try_from_ptr(op_request_ptr) { - Ok(r) => r, - Err(err) => { - tracing::error!("Could not deserialize MakeCredential request: {err}"); - return E_FAIL; - } - }; - match self.handler.make_credential(registration_request) { - Ok(registration_response) => { - // SAFETY: response pointer was given to us by Windows, so we assume it's valid. - match write_operation_response(®istration_response, response) { - Ok(()) => { - tracing::debug!("MakeCredential completed successfully"); - S_OK - } - Err(err) => { - tracing::error!( - "Failed to write MakeCredential response to Windows: {err}" - ); - return E_FAIL; - } - } - } - Err(err) => { - tracing::error!("MakeCredential failed: {err}"); - E_FAIL - } - } - } - - unsafe fn GetAssertion( - &self, - request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, - response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, - ) -> HRESULT { - tracing::debug!("GetAssertion called"); - let response = match NonNull::new(response) { - Some(p) => p, - None => { - tracing::warn!( - "GetAssertion called with null response pointer from Windows. Aborting request." - ); - return E_INVALIDARG; - } - }; - let op_request_ptr = match NonNull::new(request.cast_mut()) { - Some(p) => p, - None => { - tracing::warn!( - "GetAssertion called with null request pointer from Windows. Aborting request." - ); - return E_INVALIDARG; - } - }; - - if let Err(err) = verify_operation_request(op_request_ptr.as_ref(), &self.clsid) { - tracing::error!("Failed to verify request signature: {err}"); - return E_INVALIDARG; - } - - let assertion_request = match PluginGetAssertionRequest::try_from_ptr(op_request_ptr) { - Ok(assertion_request) => assertion_request, - Err(err) => { - tracing::error!("Could not deserialize GetAssertion request: {err}"); - return E_FAIL; - } - }; - match self.handler.get_assertion(assertion_request) { - Ok(assertion_response) => { - // SAFETY: response pointer was given to us by Windows, so we assume it's valid. - match write_operation_response(&assertion_response, response) { - Ok(()) => { - tracing::debug!("GetAssertion completed successfully"); - S_OK - } - Err(err) => { - tracing::error!("Failed to write GetCredential response to Windows: {err}"); - return E_FAIL; - } - } - } - Err(err) => { - tracing::error!("GetAssertion failed: {err}"); - E_FAIL - } - } - } - - unsafe fn CancelOperation( - &self, - request: *const WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST, - ) -> HRESULT { - tracing::debug!("CancelOperation called"); - let request = match NonNull::new(request as *mut WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST) { - Some(request) => request, - None => { - tracing::warn!("Received null CancelOperation request"); - return E_INVALIDARG; - } - }; - - match self.handler.cancel_operation(request.into()) { - Ok(()) => { - tracing::error!("CancelOperation completed successfully"); - S_OK - } - Err(err) => { - tracing::error!("CancelOperation failed: {err}"); - E_FAIL - } - } - } - - unsafe fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT { - tracing::debug!( - "GetLockStatus() called ", - std::process::id(), - std::thread::current().id() - ); - if lock_status.is_null() { - return HRESULT(-2147024809); // E_INVALIDARG - } - - match self.handler.lock_status() { - Ok(status) => { - tracing::debug!("GetLockStatus received {status:?}"); - *lock_status = status; - S_OK - } - Err(err) => { - tracing::error!("GetLockStatus failed: {err}"); - E_FAIL - } - } - } -} - -/// Copies data as COM-allocated buffer and writes to response pointer. -/// -/// Safety constraints: [response] must point to a valid -/// WEBAUTHN_PLUGIN_OPERATION_RESPONSE struct. -unsafe fn write_operation_response( - data: &[u8], - response: NonNull, -) -> Result<(), WinWebAuthnError> { - let len = match data.len().try_into() { - Ok(len) => len, - Err(err) => { - return Err(WinWebAuthnError::with_cause( - ErrorKind::Serialization, - "Response is too long to return to OS", - err, - )); - } - }; - let buf = data.to_com_buffer(); - - response.write(WEBAUTHN_PLUGIN_OPERATION_RESPONSE { - cbEncodedResponse: len, - pbEncodedResponse: buf.leak(), - }); - Ok(()) -} - -/// Registers the plugin authenticator COM library with Windows. -pub(super) fn register_server(clsid: &GUID, handler: T) -> Result<(), WinWebAuthnError> -where - T: PluginAuthenticator + Send + Sync + 'static, -{ - // Store the handler as a static so it can be initialized - HANDLER.set((*clsid, Arc::new(handler))).map_err(|_| { - WinWebAuthnError::new(ErrorKind::WindowsInternal, "Handler already initialized") - })?; - - static FACTORY: windows::core::StaticComObject = Factory.into_static(); - unsafe { - CoRegisterClassObject( - ptr::from_ref(clsid), - FACTORY.as_interface_ref(), - CLSCTX_LOCAL_SERVER, - REGCLS_MULTIPLEUSE, - ) - } - .map_err(|err| { - WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Couldn't register the COM library with Windows", - err, - ) - })?; - Ok(()) -} - -/// Initializes the COM library for use on the calling thread, -/// and registers + sets the security values. -pub(super) fn initialize() -> std::result::Result<(), WinWebAuthnError> { - let result = unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED) }; - - if result.is_err() { - return Err(WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Could not initialize the COM library", - windows::core::Error::from_hresult(result), - )); - } - - unsafe { - CoInitializeSecurity( - None, - -1, - None, - None, - RPC_C_AUTHN_LEVEL_DEFAULT, - RPC_C_IMP_LEVEL_IMPERSONATE, - None, - EOAC_NONE, - None, - ) - } - .map_err(|err| { - WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Could not initialize COM security", - err, - ) - }) -} - -pub(super) fn uninitialize() -> std::result::Result<(), WinWebAuthnError> { - unsafe { CoUninitialize() }; - Ok(()) -} - -#[repr(transparent)] -pub(super) struct ComBuffer(NonNull>); - -impl ComBuffer { - /// Returns an COM-allocated buffer of `size`. - fn alloc(size: usize, for_slice: bool) -> Self { - #[expect(clippy::as_conversions)] - { - assert!(size <= isize::MAX as usize, "requested bad object size"); - } - - // SAFETY: Any size is valid to pass to Windows, even `0`. - let ptr = NonNull::new(unsafe { CoTaskMemAlloc(size) }).unwrap_or_else(|| { - // XXX: This doesn't have to be correct, just close enough for an OK OOM error. - let layout = alloc::Layout::from_size_align(size, align_of::()).unwrap(); - alloc::handle_alloc_error(layout) - }); - - if for_slice { - // Ininitialize the buffer so it can later be treated as `&mut [u8]`. - // SAFETY: The pointer is valid and we are using a valid value for a byte-wise allocation. - unsafe { ptr.write_bytes(0, size) }; - } - - Self(ptr.cast()) - } - - pub fn leak(self) -> *mut T { - self.0.cast().as_ptr() - } -} - -pub(super) trait ComBufferExt { - fn to_com_buffer(&self) -> ComBuffer; -} - -impl ComBufferExt for Vec { - fn to_com_buffer(&self) -> ComBuffer { - ComBuffer::from(&self) - } -} - -impl ComBufferExt for &[u8] { - fn to_com_buffer(&self) -> ComBuffer { - ComBuffer::from(self) - } -} - -impl ComBufferExt for Vec { - fn to_com_buffer(&self) -> ComBuffer { - let buffer: Vec = self.into_iter().flat_map(|x| x.to_le_bytes()).collect(); - ComBuffer::from(&buffer) - } -} - -impl ComBufferExt for &[u16] { - fn to_com_buffer(&self) -> ComBuffer { - let buffer: Vec = self - .as_ref() - .into_iter() - .flat_map(|x| x.to_le_bytes()) - .collect(); - ComBuffer::from(&buffer) - } -} - -impl> From for ComBuffer { - fn from(value: T) -> Self { - let buffer: Vec = value - .as_ref() - .into_iter() - .flat_map(|x| x.to_le_bytes()) - .collect(); - let len = buffer.len(); - let com_buffer = Self::alloc(len, true); - // SAFETY: `ptr` points to a valid allocation that `len` matches, and we made sure - // the bytes were initialized. Additionally, bytes have no alignment requirements. - unsafe { - NonNull::slice_from_raw_parts(com_buffer.0.cast::(), len) - .as_mut() - .copy_from_slice(&buffer); - } - com_buffer - } -} - -unsafe fn verify_operation_request( - request: &WEBAUTHN_PLUGIN_OPERATION_REQUEST, - clsid: &GUID, -) -> Result<(), WinWebAuthnError> { - // Verify request - tracing::debug!("Verifying request"); - let request_data = - std::slice::from_raw_parts(request.pbEncodedRequest, request.cbEncodedRequest as usize); - let request_hash = crypto::hash_sha256(request_data).map_err(|err| { - WinWebAuthnError::with_cause(ErrorKind::WindowsInternal, "failed to hash request", err) - })?; - let signature = std::slice::from_raw_parts( - request.pbRequestSignature, - request.cbRequestSignature as usize, - ); - tracing::debug!("Retrieving signing key"); - let op_pub_key = crypto::get_operation_signing_public_key(clsid).map_err(|err| { - WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Failed to get signing key for operation", - err, - ) - })?; - tracing::debug!("Verifying signature"); - op_pub_key.verify_signature(&request_hash, signature) -} +//! Functions for interacting with Windows COM. +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] + +use std::{ + alloc, + mem::MaybeUninit, + ptr::{self, NonNull}, + sync::{Arc, OnceLock}, +}; + +use windows::{ + core::{implement, interface, ComObjectInterface, IUnknown, GUID, HRESULT}, + Win32::{ + Foundation::{E_FAIL, E_INVALIDARG, S_OK}, + System::Com::*, + }, +}; +use windows_core::{IInspectable, Interface}; + +use super::{ + types::{ + PluginLockStatus, WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST, + WEBAUTHN_PLUGIN_OPERATION_REQUEST, WEBAUTHN_PLUGIN_OPERATION_RESPONSE, + }, + PluginAuthenticator, +}; +use crate::{ + plugin::{crypto, PluginGetAssertionRequest, PluginMakeCredentialRequest}, + ErrorKind, WinWebAuthnError, +}; + +static HANDLER: OnceLock<(GUID, Arc)> = OnceLock::new(); + +#[implement(IClassFactory)] +pub struct Factory; + +impl IClassFactory_Impl for Factory_Impl { + fn CreateInstance( + &self, + _outer: windows::core::Ref, + iid: *const windows::core::GUID, + object: *mut *mut core::ffi::c_void, + ) -> windows::core::Result<()> { + let (clsid, handler) = match HANDLER.get() { + Some(state) => state, + None => { + tracing::error!("Cannot create COM class object instance because the handler is not initialized. register_server() must be called before starting the COM server."); + return Err(E_FAIL.into()); + } + }.clone(); + let unknown: IInspectable = PluginAuthenticatorComObject { clsid, handler }.into(); + unsafe { unknown.query(iid, object).ok() } + } + + fn LockServer(&self, _lock: windows::core::BOOL) -> windows::core::Result<()> { + // TODO: Implement lock server + Ok(()) + } +} + +// IPluginAuthenticator interface +#[interface("d26bcf6f-b54c-43ff-9f06-d5bf148625f7")] +pub unsafe trait IPluginAuthenticator: windows::core::IUnknown { + fn MakeCredential( + &self, + request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, + response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, + ) -> HRESULT; + fn GetAssertion( + &self, + request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, + response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, + ) -> HRESULT; + fn CancelOperation(&self, request: *const WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST) -> HRESULT; + fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT; +} + +#[implement(IPluginAuthenticator)] +struct PluginAuthenticatorComObject { + clsid: GUID, + handler: Arc, +} + +impl IPluginAuthenticator_Impl for PluginAuthenticatorComObject_Impl { + unsafe fn MakeCredential( + &self, + request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, + response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, + ) -> HRESULT { + tracing::debug!("MakeCredential called"); + let response = match NonNull::new(response) { + Some(p) => p, + None => { + tracing::warn!( + "MakeCredential called with null response pointer from Windows. Aborting request." + ); + return E_INVALIDARG; + } + }; + let op_request_ptr = match NonNull::new(request.cast_mut()) { + Some(p) => p, + None => { + tracing::warn!( + "MakeCredential called with null request pointer from Windows. Aborting request." + ); + return E_INVALIDARG; + } + }; + + if let Err(err) = verify_operation_request(op_request_ptr.as_ref(), &self.clsid) { + tracing::error!("Failed to verify request signature: {err}"); + return E_INVALIDARG; + } + + // SAFETY: we received the pointer from Windows, so we trust that the values are set + // properly. + let registration_request = match PluginMakeCredentialRequest::try_from_ptr(op_request_ptr) { + Ok(r) => r, + Err(err) => { + tracing::error!("Could not deserialize MakeCredential request: {err}"); + return E_FAIL; + } + }; + match self.handler.make_credential(registration_request) { + Ok(registration_response) => { + // SAFETY: response pointer was given to us by Windows, so we assume it's valid. + match write_operation_response(®istration_response, response) { + Ok(()) => { + tracing::debug!("MakeCredential completed successfully"); + S_OK + } + Err(err) => { + tracing::error!( + "Failed to write MakeCredential response to Windows: {err}" + ); + return E_FAIL; + } + } + } + Err(err) => { + tracing::error!("MakeCredential failed: {err}"); + E_FAIL + } + } + } + + unsafe fn GetAssertion( + &self, + request: *const WEBAUTHN_PLUGIN_OPERATION_REQUEST, + response: *mut WEBAUTHN_PLUGIN_OPERATION_RESPONSE, + ) -> HRESULT { + tracing::debug!("GetAssertion called"); + let response = match NonNull::new(response) { + Some(p) => p, + None => { + tracing::warn!( + "GetAssertion called with null response pointer from Windows. Aborting request." + ); + return E_INVALIDARG; + } + }; + let op_request_ptr = match NonNull::new(request.cast_mut()) { + Some(p) => p, + None => { + tracing::warn!( + "GetAssertion called with null request pointer from Windows. Aborting request." + ); + return E_INVALIDARG; + } + }; + + if let Err(err) = verify_operation_request(op_request_ptr.as_ref(), &self.clsid) { + tracing::error!("Failed to verify request signature: {err}"); + return E_INVALIDARG; + } + + let assertion_request = match PluginGetAssertionRequest::try_from_ptr(op_request_ptr) { + Ok(assertion_request) => assertion_request, + Err(err) => { + tracing::error!("Could not deserialize GetAssertion request: {err}"); + return E_FAIL; + } + }; + match self.handler.get_assertion(assertion_request) { + Ok(assertion_response) => { + // SAFETY: response pointer was given to us by Windows, so we assume it's valid. + match write_operation_response(&assertion_response, response) { + Ok(()) => { + tracing::debug!("GetAssertion completed successfully"); + S_OK + } + Err(err) => { + tracing::error!("Failed to write GetCredential response to Windows: {err}"); + return E_FAIL; + } + } + } + Err(err) => { + tracing::error!("GetAssertion failed: {err}"); + E_FAIL + } + } + } + + unsafe fn CancelOperation( + &self, + request: *const WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST, + ) -> HRESULT { + tracing::debug!("CancelOperation called"); + let request = match NonNull::new(request as *mut WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST) { + Some(request) => request, + None => { + tracing::warn!("Received null CancelOperation request"); + return E_INVALIDARG; + } + }; + + match self.handler.cancel_operation(request.into()) { + Ok(()) => { + tracing::error!("CancelOperation completed successfully"); + S_OK + } + Err(err) => { + tracing::error!("CancelOperation failed: {err}"); + E_FAIL + } + } + } + + unsafe fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT { + tracing::debug!( + "GetLockStatus() called ", + std::process::id(), + std::thread::current().id() + ); + if lock_status.is_null() { + return HRESULT(-2147024809); // E_INVALIDARG + } + + match self.handler.lock_status() { + Ok(status) => { + tracing::debug!("GetLockStatus received {status:?}"); + *lock_status = status; + S_OK + } + Err(err) => { + tracing::error!("GetLockStatus failed: {err}"); + E_FAIL + } + } + } +} + +/// Copies data as COM-allocated buffer and writes to response pointer. +/// +/// Safety constraints: [response] must point to a valid +/// WEBAUTHN_PLUGIN_OPERATION_RESPONSE struct. +unsafe fn write_operation_response( + data: &[u8], + response: NonNull, +) -> Result<(), WinWebAuthnError> { + let len = match data.len().try_into() { + Ok(len) => len, + Err(err) => { + return Err(WinWebAuthnError::with_cause( + ErrorKind::Serialization, + "Response is too long to return to OS", + err, + )); + } + }; + let buf = data.to_com_buffer(); + + response.write(WEBAUTHN_PLUGIN_OPERATION_RESPONSE { + cbEncodedResponse: len, + pbEncodedResponse: buf.leak(), + }); + Ok(()) +} + +/// Registers the plugin authenticator COM library with Windows. +pub(super) fn register_server(clsid: &GUID, handler: T) -> Result<(), WinWebAuthnError> +where + T: PluginAuthenticator + Send + Sync + 'static, +{ + // Store the handler as a static so it can be initialized + HANDLER.set((*clsid, Arc::new(handler))).map_err(|_| { + WinWebAuthnError::new(ErrorKind::WindowsInternal, "Handler already initialized") + })?; + + static FACTORY: windows::core::StaticComObject = Factory.into_static(); + unsafe { + CoRegisterClassObject( + ptr::from_ref(clsid), + FACTORY.as_interface_ref(), + CLSCTX_LOCAL_SERVER, + REGCLS_MULTIPLEUSE, + ) + } + .map_err(|err| { + WinWebAuthnError::with_cause( + ErrorKind::WindowsInternal, + "Couldn't register the COM library with Windows", + err, + ) + })?; + Ok(()) +} + +/// Initializes the COM library for use on the calling thread, +/// and registers + sets the security values. +pub(super) fn initialize() -> std::result::Result<(), WinWebAuthnError> { + let result = unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED) }; + + if result.is_err() { + return Err(WinWebAuthnError::with_cause( + ErrorKind::WindowsInternal, + "Could not initialize the COM library", + windows::core::Error::from_hresult(result), + )); + } + + unsafe { + CoInitializeSecurity( + None, + -1, + None, + None, + RPC_C_AUTHN_LEVEL_DEFAULT, + RPC_C_IMP_LEVEL_IMPERSONATE, + None, + EOAC_NONE, + None, + ) + } + .map_err(|err| { + WinWebAuthnError::with_cause( + ErrorKind::WindowsInternal, + "Could not initialize COM security", + err, + ) + }) +} + +pub(super) fn uninitialize() -> std::result::Result<(), WinWebAuthnError> { + unsafe { CoUninitialize() }; + Ok(()) +} + +#[repr(transparent)] +pub(super) struct ComBuffer(NonNull>); + +impl ComBuffer { + /// Returns an COM-allocated buffer of `size`. + fn alloc(size: usize, for_slice: bool) -> Self { + #[expect(clippy::as_conversions)] + { + assert!(size <= isize::MAX as usize, "requested bad object size"); + } + + // SAFETY: Any size is valid to pass to Windows, even `0`. + let ptr = NonNull::new(unsafe { CoTaskMemAlloc(size) }).unwrap_or_else(|| { + // XXX: This doesn't have to be correct, just close enough for an OK OOM error. + let layout = alloc::Layout::from_size_align(size, align_of::()).unwrap(); + alloc::handle_alloc_error(layout) + }); + + if for_slice { + // Ininitialize the buffer so it can later be treated as `&mut [u8]`. + // SAFETY: The pointer is valid and we are using a valid value for a byte-wise + // allocation. + unsafe { ptr.write_bytes(0, size) }; + } + + Self(ptr.cast()) + } + + pub fn leak(self) -> *mut T { + self.0.cast().as_ptr() + } +} + +pub(super) trait ComBufferExt { + fn to_com_buffer(&self) -> ComBuffer; +} + +impl ComBufferExt for Vec { + fn to_com_buffer(&self) -> ComBuffer { + ComBuffer::from(&self) + } +} + +impl ComBufferExt for &[u8] { + fn to_com_buffer(&self) -> ComBuffer { + ComBuffer::from(self) + } +} + +impl ComBufferExt for Vec { + fn to_com_buffer(&self) -> ComBuffer { + let buffer: Vec = self.into_iter().flat_map(|x| x.to_le_bytes()).collect(); + ComBuffer::from(&buffer) + } +} + +impl ComBufferExt for &[u16] { + fn to_com_buffer(&self) -> ComBuffer { + let buffer: Vec = self + .as_ref() + .into_iter() + .flat_map(|x| x.to_le_bytes()) + .collect(); + ComBuffer::from(&buffer) + } +} + +impl> From for ComBuffer { + fn from(value: T) -> Self { + let buffer: Vec = value + .as_ref() + .into_iter() + .flat_map(|x| x.to_le_bytes()) + .collect(); + let len = buffer.len(); + let com_buffer = Self::alloc(len, true); + // SAFETY: `ptr` points to a valid allocation that `len` matches, and we made sure + // the bytes were initialized. Additionally, bytes have no alignment requirements. + unsafe { + NonNull::slice_from_raw_parts(com_buffer.0.cast::(), len) + .as_mut() + .copy_from_slice(&buffer); + } + com_buffer + } +} + +unsafe fn verify_operation_request( + request: &WEBAUTHN_PLUGIN_OPERATION_REQUEST, + clsid: &GUID, +) -> Result<(), WinWebAuthnError> { + // Verify request + tracing::debug!("Verifying request"); + let request_data = + std::slice::from_raw_parts(request.pbEncodedRequest, request.cbEncodedRequest as usize); + let request_hash = crypto::hash_sha256(request_data).map_err(|err| { + WinWebAuthnError::with_cause(ErrorKind::WindowsInternal, "failed to hash request", err) + })?; + let signature = std::slice::from_raw_parts( + request.pbRequestSignature, + request.cbRequestSignature as usize, + ); + tracing::debug!("Retrieving signing key"); + let op_pub_key = crypto::get_operation_signing_public_key(clsid).map_err(|err| { + WinWebAuthnError::with_cause( + ErrorKind::WindowsInternal, + "Failed to get signing key for operation", + err, + ) + })?; + tracing::debug!("Verifying signature"); + op_pub_key.verify_signature(&request_hash, signature) +} diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs index d0e8455576e..d4b7125e99f 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs @@ -3,18 +3,18 @@ pub(crate) mod crypto; pub(crate) mod types; use std::{error::Error, ptr::NonNull}; -use types::*; -use windows::{ - core::GUID, - Win32::Foundation::{NTE_USER_CANCELLED, S_OK}, -}; +use types::*; pub use types::{ PluginAddAuthenticatorOptions, PluginAddAuthenticatorResponse, PluginCancelOperationRequest, PluginCredentialDetails, PluginGetAssertionRequest, PluginLockStatus, PluginMakeCredentialRequest, PluginMakeCredentialResponse, PluginUserVerificationRequest, PluginUserVerificationResponse, }; +use windows::{ + core::GUID, + Win32::Foundation::{NTE_USER_CANCELLED, S_OK}, +}; use super::{ErrorKind, WinWebAuthnError}; use crate::{ @@ -142,7 +142,8 @@ impl WebAuthnPlugin { })?; if let Some(response) = NonNull::new(response_ptr) { - // SAFETY: The pointer was allocated by a successful call to webauthn_plugin_add_authenticator. + // SAFETY: The pointer was allocated by a successful call to + // webauthn_plugin_add_authenticator. Ok(PluginAddAuthenticatorResponse::try_from_ptr(response)) } else { Err(WinWebAuthnError::new( @@ -189,7 +190,8 @@ impl WebAuthnPlugin { } else { // SAFETY: Windows only runs on platforms where usize >= u32; let len = response_len as usize; - // SAFETY: Windows returned successful response code and length, so we assume that the data is initialized + // SAFETY: Windows returned successful response code and length, so we + // assume that the data is initialized let signature = std::slice::from_raw_parts(response_ptr, len).to_vec(); pub_key.verify_signature(operation_request, &signature)?; signature @@ -295,7 +297,8 @@ impl WebAuthnPlugin { ); } - // SAFETY: The pointer to win_credentials lives longer than the call to webauthn_plugin_authenticator_add_credentials(). + // SAFETY: The pointer to win_credentials lives longer than the call to + // webauthn_plugin_authenticator_add_credentials(). let result = unsafe { webauthn_plugin_authenticator_add_credentials( &self.clsid.0, diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs index dfc91dc38d9..c120dcac236 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs @@ -13,24 +13,20 @@ use windows::{ }; use windows_core::BOOL; +use super::Clsid; use crate::{ plugin::crypto, types::{ - CredentialEx, RpEntityInformation, UserEntityInformation, UserId, - WEBAUTHN_COSE_CREDENTIAL_PARAMETER, + AuthenticatorInfo, CredentialEx, CtapTransport, HmacSecretSalt, RpEntityInformation, + UserEntityInformation, UserId, WebAuthnExtensionMakeCredentialOutput, + WEBAUTHN_COSE_CREDENTIAL_PARAMETER, WEBAUTHN_COSE_CREDENTIAL_PARAMETERS, + WEBAUTHN_CREDENTIAL_ATTESTATION, WEBAUTHN_CREDENTIAL_LIST, WEBAUTHN_EXTENSIONS, + WEBAUTHN_RP_ENTITY_INFORMATION, WEBAUTHN_USER_ENTITY_INFORMATION, }, util::{webauthn_call, WindowsString}, CredentialId, ErrorKind, WinWebAuthnError, }; -use crate::types::{ - AuthenticatorInfo, CtapTransport, HmacSecretSalt, WebAuthnExtensionMakeCredentialOutput, - WEBAUTHN_COSE_CREDENTIAL_PARAMETERS, WEBAUTHN_CREDENTIAL_ATTESTATION, WEBAUTHN_CREDENTIAL_LIST, - WEBAUTHN_EXTENSIONS, WEBAUTHN_RP_ENTITY_INFORMATION, WEBAUTHN_USER_ENTITY_INFORMATION, -}; - -use super::Clsid; - // Plugin Registration types /// Windows WebAuthn Authenticator Options structure @@ -190,7 +186,8 @@ impl PluginAddAuthenticatorResponse { unsafe { std::slice::from_raw_parts( self.inner.as_ref().pbOpSignPubKey, - // SAFETY: We only support 32-bit or 64-bit platforms, so u32 will always fit in usize. + // SAFETY: We only support 32-bit or 64-bit platforms, so u32 will always fit in + // usize. self.inner.as_ref().cbOpSignPubKey as usize, ) } @@ -211,7 +208,8 @@ impl Drop for PluginAddAuthenticatorResponse { fn drop(&mut self) { unsafe { // SAFETY: This should only fail if: - // - we cannot load the webauthn.dll, which we already have if we have constructed this type, or + // - we cannot load the webauthn.dll, which we already have if we have constructed this + // type, or // - we spelled the function wrong, which is a library error. webauthn_plugin_free_add_authenticator_response(self.inner.as_mut()) .expect("function to load properly"); @@ -422,12 +420,14 @@ impl PluginMakeCredentialRequest { pub fn rp_information(&self) -> RpEntityInformation<'_> { let ptr = self.as_ref().pRpInformation; - // SAFETY: When this is constructed using Self::try_from_ptr(), the caller must ensure that pRpInformation is valid. + // SAFETY: When this is constructed using Self::try_from_ptr(), the caller must ensure that + // pRpInformation is valid. unsafe { RpEntityInformation::new(ptr.as_ref().expect("pRpInformation to be non-null")) } } pub fn user_information(&self) -> UserEntityInformation<'_> { - // SAFETY: When this is constructed using Self::try_from_ptr(), the caller must ensure that pUserInformation is valid. + // SAFETY: When this is constructed using Self::try_from_ptr(), the caller must ensure that + // pUserInformation is valid. let ptr = self.as_ref().pUserInformation; assert!(!ptr.is_null()); unsafe { @@ -436,12 +436,14 @@ impl PluginMakeCredentialRequest { } pub fn pub_key_cred_params(&self) -> impl Iterator { - // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API constructs valid pointers. + // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API + // constructs valid pointers. unsafe { self.as_ref().WebAuthNCredentialParameters.iter() } } pub fn exclude_credentials(&self) -> impl Iterator> { - // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API constructs valid pointers. + // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API + // constructs valid pointers. unsafe { self.as_ref().CredentialList.iter() } } @@ -576,37 +578,31 @@ pub struct PluginMakeCredentialResponse { // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_2 - // /// Since VERSION 2 pub extensions: Option>, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_3 - // /// One of the WEBAUTHN_CTAP_TRANSPORT_* bits will be set corresponding to /// the transport that was used. pub used_transport: CtapTransport, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_4 - // pub ep_att: bool, pub large_blob_supported: bool, pub resident_key: bool, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_5 - // pub prf_enabled: bool, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_6 - // pub unsigned_extension_outputs: Option>, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_7 - // pub hmac_secret: Option, /// ThirdPartyPayment Credential or not. @@ -614,7 +610,6 @@ pub struct PluginMakeCredentialResponse { // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_8 - // /// Multiple WEBAUTHN_CTAP_TRANSPORT_* bits will be set corresponding to /// the transports that are supported. pub transports: Option>, @@ -631,7 +626,8 @@ impl PluginMakeCredentialResponse { let attestation = self.try_into()?; let mut response_len = 0; let mut response_ptr = std::ptr::null_mut(); - // SAFETY: we construct valid input and check the OS error code before using the returned value. + // SAFETY: we construct valid input and check the OS error code before using the returned + // value. unsafe { webauthn_encode_make_credential_response( &attestation, @@ -892,7 +888,8 @@ impl PluginGetAssertionRequest { } pub fn allow_credentials(&self) -> impl Iterator> { - // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API constructs valid pointers. + // SAFETY: When this is constructed from Self::try_from_ptr(), the Windows decode API + // constructs valid pointers. unsafe { self.as_ref().CredentialList.iter() } } diff --git a/apps/desktop/desktop_native/win_webauthn/src/types/mod.rs b/apps/desktop/desktop_native/win_webauthn/src/types/mod.rs index 392c5ce21e9..b8213f4a643 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/types/mod.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/types/mod.rs @@ -139,7 +139,8 @@ impl TryFrom<&str> for Uuid { }) .collect::, WinWebAuthnError>>()?; - // SAFETY: We already checked the length of the string before, so this should result in the correct number of bytes. + // SAFETY: We already checked the length of the string before, so this should result in the + // correct number of bytes. let b: [u8; 16] = bytes.try_into().expect("16 bytes to be parsed"); Ok(Uuid(b)) } @@ -179,8 +180,8 @@ pub(crate) struct WEBAUTHN_RP_ENTITY_INFORMATION { /// Contains the friendly name of the Relying Party, such as "Acme /// Corporation", "Widgets Inc" or "Awesome Site". /// - /// This member is deprecated in WebAuthn Level 3 because many clients do not display it, but it - /// remains a required dictionary member for backwards compatibility. Relying + /// This member is deprecated in WebAuthn Level 3 because many clients do not display it, but + /// it remains a required dictionary member for backwards compatibility. Relying /// Parties MAY, as a safe default, set this equal to the RP ID. pwszName: *const u16, // PCWSTR @@ -235,7 +236,8 @@ impl RpEntityInformation<'_> { .to_string() .expect("null-terminated UTF-16 string or null"); - // WebAuthn Level 3 deprecates the use of the `name` field, so verify whether this is empty or not. + // WebAuthn Level 3 deprecates the use of the `name` field, so verify whether this is + // empty or not. if s.is_empty() { None } else { @@ -264,7 +266,8 @@ pub(crate) struct WEBAUTHN_USER_ENTITY_INFORMATION { #[deprecated] pub pwszIcon: Option>, // PCWSTR - /// Contains the friendly name associated with the user account by the Relying Party, such as "John P. Smith". + /// Contains the friendly name associated with the user account by the Relying Party, such as + /// "John P. Smith". pub pwszDisplayName: NonNull, // PCWSTR } @@ -411,39 +414,33 @@ pub(crate) struct WEBAUTHN_CREDENTIAL_ATTESTATION { // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_2 - // /// Since VERSION 2 pub(crate) Extensions: WEBAUTHN_EXTENSIONS, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_3 - // /// One of the WEBAUTHN_CTAP_TRANSPORT_* bits will be set corresponding to /// the transport that was used. pub(crate) dwUsedTransport: u32, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_4 - // pub(crate) bEpAtt: bool, pub(crate) bLargeBlobSupported: bool, pub(crate) bResidentKey: bool, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_5 - // pub(crate) bPrfEnabled: bool, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_6 - // pub(crate) cbUnsignedExtensionOutputs: u32, // _Field_size_bytes_(cbUnsignedExtensionOutputs) pub(crate) pbUnsignedExtensionOutputs: *const u8, // // Following fields have been added in WEBAUTHN_CREDENTIAL_ATTESTATION_VERSION_7 - // pub(crate) pHmacSecret: *const WEBAUTHN_HMAC_SECRET_SALT, // ThirdPartyPayment Credential or not.