diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs index 064d34bd719..e629e002f2d 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/com.rs @@ -12,7 +12,7 @@ use std::{ use windows::{ core::{implement, interface, ComObjectInterface, IUnknown, GUID, HRESULT}, Win32::{ - Foundation::{E_FAIL, E_INVALIDARG, S_OK}, + Foundation::{E_FAIL, E_INVALIDARG, S_FALSE, S_OK}, System::Com::*, }, }; @@ -31,6 +31,7 @@ use crate::{ }; static HANDLER: OnceLock<(GUID, Arc)> = OnceLock::new(); +static SHUTDOWN: OnceLock = OnceLock::new(); #[implement(IClassFactory)] pub struct Factory; @@ -284,6 +285,40 @@ pub(super) fn register_server(clsid: &GUID, handler: T) -> Result<(), WinWebA where T: PluginAuthenticator + Send + Sync + 'static, { + if HANDLER.get().is_some() { + return Err(WinWebAuthnError::new( + ErrorKind::Other, + "server can only be registered one time per process", + )); + } + unsafe { + let com_init_result = CoInitializeEx(None, COINIT_APARTMENTTHREADED); + match com_init_result { + S_OK | S_FALSE => {} // mark initialized, + code => { + return Err(WinWebAuthnError::with_cause( + ErrorKind::WindowsInternal, + "Could not initialize the COM library", + windows::core::Error::from_hresult(code), + )); + } + } + + if let Err(err) = CoInitializeSecurity( + None, + -1, + None, + None, + RPC_C_AUTHN_LEVEL_DEFAULT, + RPC_C_IMP_LEVEL_IMPERSONATE, + None, + EOAC_NONE, + None, + ) { + tracing::warn!("Could not initialize COM security: {err}"); + }; + } + // Store the handler as a static so it can be initialized HANDLER.set((*clsid, Arc::new(handler))).map_err(|_| { WinWebAuthnError::new(ErrorKind::WindowsInternal, "Handler already initialized") @@ -308,43 +343,16 @@ where Ok(()) } -/// Initializes the COM library for use on the calling thread, -/// and registers + sets the security values. -pub(super) fn initialize() -> std::result::Result<(), WinWebAuthnError> { - let result = unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED) }; - - if result.is_err() { - return Err(WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Could not initialize the COM library", - windows::core::Error::from_hresult(result), - )); +pub(super) fn shutdown_server() -> std::result::Result<(), WinWebAuthnError> { + if HANDLER.get().is_some() { + if let Ok(()) = SHUTDOWN.set(true) { + unsafe { CoUninitialize() }; + } else { + tracing::debug!("server already shut down"); + } + } else { + tracing::debug!("server was not registered. Ignoring."); } - - unsafe { - CoInitializeSecurity( - None, - -1, - None, - None, - RPC_C_AUTHN_LEVEL_DEFAULT, - RPC_C_IMP_LEVEL_IMPERSONATE, - None, - EOAC_NONE, - None, - ) - } - .map_err(|err| { - WinWebAuthnError::with_cause( - ErrorKind::WindowsInternal, - "Could not initialize COM security", - err, - ) - }) -} - -pub(super) fn uninitialize() -> std::result::Result<(), WinWebAuthnError> { - unsafe { CoUninitialize() }; Ok(()) } diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs index d4b7125e99f..14d889e26cb 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/mod.rs @@ -76,10 +76,9 @@ impl WebAuthnPlugin { com::register_server(&self.clsid.0, handler) } - /// Initializes the COM library for use on the calling thread, - /// and registers + sets the security values. - pub fn initialize() -> Result<(), WinWebAuthnError> { - com::initialize() + /// Uninitializes the COM library for the calling thread. + pub fn shutdown_server() -> Result<(), WinWebAuthnError> { + com::shutdown_server() } /// Adds this implementation as a Windows WebAuthn plugin. diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/process.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/process.rs index fd339939cd5..59ca7be020a 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/process.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/process.rs @@ -74,19 +74,16 @@ pub(super) fn add_authenticator() -> Result<(), String> { pub(super) fn run_server() -> Result<(), String> { tracing::debug!("Setting up COM server"); - let r = WebAuthnPlugin::initialize(); - tracing::debug!( - "Initialized the com library with WebAuthnPlugin::initialize(): {:?}", - r - ); let clsid = CLSID.try_into().expect("valid GUID string"); let plugin = WebAuthnPlugin::new(clsid); - let r = plugin.register_server(BitwardenPluginAuthenticator { - client: Mutex::new(None), - callbacks: Arc::new(Mutex::new(HashMap::new())), - }); + let r = plugin + .register_server(BitwardenPluginAuthenticator { + client: Mutex::new(None), + callbacks: Arc::new(Mutex::new(HashMap::new())), + }) + .map_err(|err| err.to_string())?; tracing::debug!("Registered the com library: {:?}", r); Ok(()) }