diff --git a/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs b/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs index a69ea62f01f..687fd601337 100644 --- a/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs +++ b/apps/desktop/desktop_native/win_webauthn/src/plugin/types.rs @@ -892,12 +892,12 @@ pub struct PluginCancelOperationRequest { impl PluginCancelOperationRequest { /// Request transaction ID - fn transaction_id(&self) -> GUID { + pub fn transaction_id(&self) -> GUID { self.as_ref().transactionId } /// Request signature. - fn request_signature(&self) -> &[u8] { + pub fn request_signature(&self) -> &[u8] { unsafe { std::slice::from_raw_parts( self.as_ref().pbRequestSignature, diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs index c3e5c0df3a8..579faddc841 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/assert.rs @@ -1,11 +1,14 @@ use serde_json; -use std::{sync::Arc, time::Duration}; +use std::{ + sync::{mpsc::Receiver, Arc}, + time::Duration, +}; use win_webauthn::plugin::PluginGetAssertionRequest; use crate::{ ipc2::{ - PasskeyAssertionRequest, PasskeyAssertionResponse, + CallbackError, PasskeyAssertionRequest, PasskeyAssertionResponse, PasskeyAssertionWithoutUserInterfaceRequest, Position, TimedCallback, UserVerification, WindowsProviderClient, }, @@ -15,6 +18,7 @@ use crate::{ pub fn get_assertion( ipc_client: &WindowsProviderClient, request: PluginGetAssertionRequest, + cancellation_token: Receiver<()>, ) -> Result, Box> { // Extract RP information let rp_id = request.rp_id().to_string(); @@ -64,8 +68,9 @@ pub fn get_assertion( }, context: transaction_id, }; - let passkey_response = send_assertion_request(ipc_client, assertion_request) - .map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?; + let passkey_response = + send_assertion_request(ipc_client, assertion_request, cancellation_token) + .map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?; tracing::debug!("Assertion response received: {:?}", passkey_response); // Create proper WebAuthn response from passkey_response @@ -84,6 +89,7 @@ pub fn get_assertion( fn send_assertion_request( ipc_client: &WindowsProviderClient, request: PasskeyAssertionRequest, + cancellation_token: Receiver<()>, ) -> Result { tracing::debug!( "Assertion request data - RP ID: {}, Client data hash: {} bytes, Allowed credentials: {:?}", @@ -113,9 +119,13 @@ fn send_assertion_request( } else { ipc_client.prepare_passkey_assertion(request, callback.clone()); } + let wait_time = Duration::from_secs(600); callback - .wait_for_response(Duration::from_secs(30)) - .map_err(|_| "Registration request timed out".to_string())? + .wait_for_response(wait_time, Some(cancellation_token)) + .map_err(|err| match err { + CallbackError::Timeout => "Assertion request timed out".to_string(), + CallbackError::Cancelled => "Assertion request cancelled".to_string(), + })? .map_err(|err| err.to_string()) } diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs index b4ecedab27f..bd44666fcd3 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs @@ -4,7 +4,7 @@ use std::{ fmt::Display, sync::{ atomic::AtomicU32, - mpsc::{self, Receiver, Sender}, + mpsc::{self, Receiver, RecvError, RecvTimeoutError, Sender}, Arc, Mutex, }, time::{Duration, Instant}, @@ -285,25 +285,60 @@ impl WindowsProviderClient { } } -pub struct TimedCallback { - tx: Mutex>>>, - rx: Mutex>>, +pub enum CallbackError { + Timeout, + Cancelled, } -impl TimedCallback { +pub struct TimedCallback { + tx: Arc>>>>, + rx: Arc>>>, +} + +impl TimedCallback { pub fn new() -> Self { let (tx, rx) = mpsc::channel(); Self { - tx: Mutex::new(Some(tx)), - rx: Mutex::new(rx), + tx: Arc::new(Mutex::new(Some(tx))), + rx: Arc::new(Mutex::new(rx)), } } pub fn wait_for_response( &self, timeout: Duration, - ) -> Result, mpsc::RecvTimeoutError> { - self.rx.lock().unwrap().recv_timeout(timeout) + cancellation_token: Option>, + ) -> Result, CallbackError> { + let (tx, rx) = mpsc::channel(); + if let Some(cancellation_token) = cancellation_token { + let tx2 = tx.clone(); + let cancellation_token = Mutex::new(cancellation_token); + std::thread::spawn(move || { + if let Ok(()) = cancellation_token.lock().unwrap().recv_timeout(timeout) { + tracing::debug!("Forwarding cancellation"); + _ = tx2.send(Err(CallbackError::Cancelled)); + } + }); + } + let response_rx = self.rx.clone(); + std::thread::spawn(move || { + if let Ok(response) = response_rx.lock().unwrap().recv_timeout(timeout) { + _ = tx.send(Ok(response)); + } + }); + match rx.recv_timeout(timeout) { + Ok(Ok(response)) => Ok(response), + Ok(err @ Err(CallbackError::Cancelled)) => { + tracing::debug!("Received cancellation, dropping."); + err + } + Ok(err @ Err(CallbackError::Timeout)) => { + tracing::debug!("Request timed out, dropping."); + err + } + Err(RecvTimeoutError::Timeout) => Err(CallbackError::Timeout), + Err(_) => Err(CallbackError::Cancelled), + } } fn send(&self, response: Result) { diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs index be92d1df1ca..9fc2c59b34f 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs @@ -9,11 +9,16 @@ mod make_credential; mod types; mod util; -use std::{collections::HashSet, sync::Arc, time::Duration}; - -// Re-export main functionality -pub use types::UserVerificationRequirement; +use std::{ + collections::{HashMap, HashSet}, + sync::{ + mpsc::{self, Sender}, + Arc, Mutex, + }, + time::Duration, +}; +use base64::engine::{general_purpose::STANDARD, Engine as _}; use win_webauthn::{ plugin::{ PluginAddAuthenticatorOptions, PluginAuthenticator, PluginCancelOperationRequest, @@ -21,8 +26,12 @@ use win_webauthn::{ }, AuthenticatorInfo, CtapVersion, PublicKeyCredentialParameters, }; +use windows_core::GUID; -use crate::ipc2::{TimedCallback, WindowsProviderClient}; +use crate::ipc2::{ConnectionStatus, TimedCallback, WindowsProviderClient}; + +// Re-export main functionality +pub use types::UserVerificationRequirement; const AUTHENTICATOR_NAME: &str = "Bitwarden Desktop"; const RPID: &str = "bitwarden.com"; @@ -45,7 +54,10 @@ pub fn register() -> std::result::Result<(), String> { let clsid = CLSID.try_into().expect("valid GUID string"); let plugin = WebAuthnPlugin::new(clsid); - let r = plugin.register_server(BitwardenPluginAuthenticator); + let r = plugin.register_server(BitwardenPluginAuthenticator { + client: Mutex::new(None), + callbacks: Arc::new(Mutex::new(HashMap::new())), + }); tracing::debug!("Registered the com library: {:?}", r); tracing::debug!("Parsing authenticator options"); @@ -83,14 +95,28 @@ pub fn register() -> std::result::Result<(), String> { Ok(()) } -struct BitwardenPluginAuthenticator; +struct BitwardenPluginAuthenticator { + /// Client to communicate with desktop app over IPC. + client: Mutex>>, + + /// Map of transaction IDs to cancellation tokens + callbacks: Arc>>>, +} impl BitwardenPluginAuthenticator { - fn get_client(&self) -> WindowsProviderClient { + fn get_client(&self) -> Arc { tracing::debug!("Connecting to client via IPC"); - let client = WindowsProviderClient::connect(); - tracing::debug!("Connected to client via IPC successfully"); - client + let mut client = self.client.lock().unwrap(); + match client.as_ref().map(|c| (c, c.get_connection_status())) { + Some((_, ConnectionStatus::Disconnected)) | None => { + tracing::debug!("Connecting to desktop app"); + let c = WindowsProviderClient::connect(); + tracing::debug!("Connected to client via IPC successfully"); + _ = client.insert(Arc::new(c)); + } + _ => {} + }; + client.as_ref().unwrap().clone() } } @@ -101,7 +127,18 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { ) -> Result, Box> { tracing::debug!("Received MakeCredential: {request:?}"); let client = self.get_client(); - make_credential::make_credential(&client, request) + let (cancel_tx, cancel_rx) = mpsc::channel(); + let transaction_id = request.transaction_id; + self.callbacks + .lock() + .expect("not poisoned") + .insert(transaction_id, cancel_tx); + let response = make_credential::make_credential(&client, request, cancel_rx); + self.callbacks + .lock() + .expect("not poisoned") + .remove(&transaction_id); + response } fn get_assertion( @@ -110,13 +147,39 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { ) -> Result, Box> { tracing::debug!("Received GetAssertion: {request:?}"); let client = self.get_client(); - assert::get_assertion(&client, request) + let (cancel_tx, cancel_rx) = mpsc::channel(); + let transaction_id = request.transaction_id; + self.callbacks + .lock() + .expect("not poisoned") + .insert(transaction_id, cancel_tx); + let response = assert::get_assertion(&client, request, cancel_rx); + self.callbacks + .lock() + .expect("not poisoned") + .remove(&transaction_id); + response } fn cancel_operation( &self, request: PluginCancelOperationRequest, ) -> Result<(), Box> { + let transaction_id = request.transaction_id(); + tracing::debug!(?transaction_id, "Received CancelOperation"); + + if let Some(cancellation_token) = self + .callbacks + .lock() + .expect("not poisoned") + .get(&request.transaction_id()) + { + _ = cancellation_token.send(()); + let client = self.get_client(); + let context = STANDARD.encode(transaction_id.to_u128().to_le_bytes().to_vec()); + tracing::debug!("Sending cancel operation for context: {context}"); + client.send_native_status("cancel-operation".to_string(), context); + } Ok(()) } @@ -124,7 +187,7 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { let callback = Arc::new(TimedCallback::new()); let client = self.get_client(); client.get_lock_status(callback.clone()); - match callback.wait_for_response(Duration::from_secs(3)) { + match callback.wait_for_response(Duration::from_secs(3), None) { Ok(Ok(response)) => { if response.is_unlocked { Ok(PluginLockStatus::PluginUnlocked) diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/make_credential.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/make_credential.rs index 9cfc7448705..53bbdfc1a20 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/make_credential.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/make_credential.rs @@ -1,6 +1,7 @@ use serde_json; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::mpsc::TryRecvError; +use std::sync::{mpsc::Receiver, Arc}; use std::time::Duration; use win_webauthn::{ @@ -8,6 +9,7 @@ use win_webauthn::{ CtapTransport, }; +use crate::ipc2::CallbackError; use crate::{ ipc2::{ PasskeyRegistrationRequest, PasskeyRegistrationResponse, Position, TimedCallback, @@ -19,6 +21,7 @@ use crate::{ pub fn make_credential( ipc_client: &WindowsProviderClient, request: PluginMakeCredentialRequest, + cancellation_token: Receiver<()>, ) -> Result, Box> { tracing::debug!("=== PluginMakeCredential() called ==="); @@ -113,9 +116,14 @@ pub fn make_credential( registration_request.user_name ); + if let Ok(()) = cancellation_token.try_recv() { + return Err(format!("Request {:?} cancelled", request.transaction_id))?; + } + // Send registration request - let passkey_response = send_registration_request(ipc_client, registration_request) - .map_err(|err| format!("Registration request failed: {err}"))?; + let passkey_response = + send_registration_request(ipc_client, registration_request, cancellation_token) + .map_err(|err| format!("Registration request failed: {err}"))?; tracing::debug!("Registration response received: {:?}", passkey_response); // Create proper WebAuthn response from passkey_response @@ -130,6 +138,7 @@ pub fn make_credential( fn send_registration_request( ipc_client: &WindowsProviderClient, request: PasskeyRegistrationRequest, + cancellation_token: Receiver<()>, ) -> Result { tracing::debug!("Registration request data - RP ID: {}, User ID: {} bytes, User name: {}, Client data hash: {} bytes, Algorithms: {:?}, Excluded credentials: {}", request.rp_id, request.user_handle.len(), request.user_name, request.client_data_hash.len(), request.supported_algorithms, request.excluded_credentials.len()); @@ -139,9 +148,15 @@ fn send_registration_request( tracing::debug!("Sending registration request: {}", request_json); let callback = Arc::new(TimedCallback::new()); ipc_client.prepare_passkey_registration(request, callback.clone()); + // Corresponds to maximum recommended timeout for WebAuthn. + // https://www.w3.org/TR/webauthn-3/#recommended-range-and-default-for-a-webauthn-ceremony-timeout + let wait_time = Duration::from_secs(600); let response = callback - .wait_for_response(Duration::from_secs(30)) - .map_err(|_| "Registration request timed out".to_string())? + .wait_for_response(wait_time, Some(cancellation_token)) + .map_err(|err| match err { + CallbackError::Timeout => "Registration request timed out".to_string(), + CallbackError::Cancelled => "Registration request cancelled".to_string(), + })? .map_err(|err| err.to_string()); if response.is_ok() { tracing::debug!("Requesting credential sync after registering a new credential."); diff --git a/apps/desktop/src/autofill/services/desktop-autofill.service.ts b/apps/desktop/src/autofill/services/desktop-autofill.service.ts index 6ea86f26d98..952e4bde04a 100644 --- a/apps/desktop/src/autofill/services/desktop-autofill.service.ts +++ b/apps/desktop/src/autofill/services/desktop-autofill.service.ts @@ -54,6 +54,7 @@ const NativeCredentialSyncFeatureFlag = ipc.platform.deviceType === DeviceType.W export class DesktopAutofillService implements OnDestroy { private destroy$ = new Subject(); private registrationRequest: autofill.PasskeyRegistrationRequest; + private inFlightRequests: Record = {}; constructor( private logService: LogService, @@ -210,6 +211,12 @@ export class DesktopAutofillService implements OnDestroy { this.logService.debug("listenPasskeyRegistration2", this.convertRegistrationRequest(request)); const controller = new AbortController(); + let requestId = request.context ? this.contextToRequestId(request.context) : null; + this.logService.debug("Request context:", requestId) + if (requestId) { + this.inFlightRequests[requestId] = controller; + } + const ctx = request.context ? new Uint8Array(request.context).buffer : null; try { @@ -226,6 +233,11 @@ export class DesktopAutofillService implements OnDestroy { this.logService.error("listenPasskeyRegistration error", error); callback(error, null); } + finally { + if (requestId) { + delete this.inFlightRequests[requestId]; + } + } this.logService.info("Passkey registration completed.") }); @@ -247,6 +259,10 @@ export class DesktopAutofillService implements OnDestroy { ); const controller = new AbortController(); + let requestId = request.context ? this.contextToRequestId(request.context) : null; + if (requestId) { + this.inFlightRequests[requestId] = controller; + } try { // For some reason the credentialId is passed as an empty array in the request, so we need to @@ -296,6 +312,11 @@ export class DesktopAutofillService implements OnDestroy { callback(error, null); return; } + finally { + if (requestId) { + delete this.inFlightRequests[requestId]; + } + } }, ); @@ -311,6 +332,11 @@ export class DesktopAutofillService implements OnDestroy { this.logService.debug("listenPasskeyAssertion", clientId, sequenceNumber, request); const ctx = request.context ? new Uint8Array(request.context).buffer : null; const controller = new AbortController(); + let requestId = request.context ? this.contextToRequestId(request.context) : null; + if (requestId) { + this.inFlightRequests[requestId] = controller; + } + try { const response = await this.fido2AuthenticatorService.getAssertion( this.convertAssertionRequest(request), @@ -324,6 +350,11 @@ export class DesktopAutofillService implements OnDestroy { this.logService.error("listenPasskeyAssertion error", error); callback(error, null); } + finally { + if (requestId) { + delete this.inFlightRequests[requestId]; + } + } }); // Listen for native status messages @@ -340,6 +371,18 @@ export class DesktopAutofillService implements OnDestroy { // perform ad-hoc sync await this.adHocSync(); } + + if (status.key === "cancel-operation" && status.value) { + const requestId = status.value + const controller = this.inFlightRequests[requestId] + if (controller) { + this.logService.debug(`Cancelling request ${requestId}`); + controller.abort("Operation cancelled") + } + else { + this.logService.debug(`Unknown request: ${requestId}`); + } + } }); ipc.autofill.listenLockStatusQuery(async (clientId, sequenceNumber, request, callback) => { @@ -457,6 +500,12 @@ export class DesktopAutofillService implements OnDestroy { }; } + private contextToRequestId(context: number[]): string { + const buf = new Uint8Array(context).buffer; + const requestId = Utils.fromBufferToB64(buf); + return requestId + } + ngOnDestroy(): void { this.destroy$.next(); this.destroy$.complete(); diff --git a/apps/desktop/src/autofill/services/desktop-fido2-user-interface.service.ts b/apps/desktop/src/autofill/services/desktop-fido2-user-interface.service.ts index 8e393567039..629c00881b5 100644 --- a/apps/desktop/src/autofill/services/desktop-fido2-user-interface.service.ts +++ b/apps/desktop/src/autofill/services/desktop-fido2-user-interface.service.ts @@ -78,6 +78,7 @@ export class DesktopFido2UserInterfaceService this.router, this.desktopSettingsService, nativeWindowObject, + abortController, transactionContext, ); @@ -95,6 +96,7 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi private router: Router, private desktopSettingsService: DesktopSettingsService, private windowObject: NativeWindowObject, + private abortController: AbortController, private transactionContext: ArrayBuffer, ) {} @@ -177,15 +179,22 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi private async waitForUiChosenCipher( timeoutMs: number = 60000, ): Promise<{ cipherId?: string; userVerified: boolean } | undefined> { + const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation(); try { - return await lastValueFrom(this.chosenCipherSubject.pipe(timeout(timeoutMs))); - } catch { + this.abortController.signal.throwIfAborted(); + const confirmPromise = lastValueFrom(this.chosenCipherSubject.pipe(timeout(timeoutMs))); + return await Promise.race([confirmPromise, cancelPromise]); + } catch (e) { // If we hit a timeout, return undefined instead of throwing + this.logService.debug("Timed out or cancelled?", e); this.logService.warning("Timeout: User did not select a cipher within the allowed time", { timeoutMs, }); return { cipherId: undefined, userVerified: false }; } + finally { + this.unsusbscribeCancellation(abortFn); + } } /** @@ -204,7 +213,19 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi * @returns */ private async waitForUiNewCredentialConfirmation(): Promise { - return lastValueFrom(this.confirmCredentialSubject); + const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation(); + try { + this.abortController.signal.throwIfAborted(); + const confirmPromise = lastValueFrom(this.confirmCredentialSubject); + return await Promise.race([confirmPromise, cancelPromise]); + } catch (e) { + // If we hit a timeout, return undefined instead of throwing + this.logService.debug("Timed out or cancelled?", e); + return undefined; + } + finally { + this.unsusbscribeCancellation(abortFn); + } } /** @@ -375,17 +396,21 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi await this.showUi("/lock", this.windowObject.windowXy, true, true); let status2: AuthenticationStatus; + const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation(); try { - status2 = await lastValueFrom( + status2 = await Promise.race([lastValueFrom( this.authService.activeAccountStatus$.pipe( filter((s) => s === AuthenticationStatus.Unlocked), take(1), timeout(1000 * 60 * 5), // 5 minutes ), - ); + ), cancelPromise]); } catch (error) { this.logService.warning("Error while waiting for vault to unlock", error); } + finally { + this.unsusbscribeCancellation(abortFn); + } if (status2 === AuthenticationStatus.Unlocked) { await this.router.navigate(["/"]); @@ -405,4 +430,24 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi async close() { this.logService.debug("close"); } + + subscribeToCancellation() { + let cancelReject: (reason?: any) => void; + const cancelPromise: Promise = new Promise((_, reject) => { + cancelReject = reject + }); + const abortFn = (ev: Event) => { + if (ev.target instanceof AbortSignal) { + cancelReject(ev.target.reason) + } + }; + this.abortController.signal.addEventListener("abort", abortFn, { once: true }); + + return { promise: cancelPromise, listener: abortFn }; + + } + + unsusbscribeCancellation(listener: (ev: Event) => void): void { + this.abortController.signal.removeEventListener("abort", listener); + } } diff --git a/libs/common/src/platform/services/fido2/fido2-authenticator.service.ts b/libs/common/src/platform/services/fido2/fido2-authenticator.service.ts index 54b7fdbeb93..2f7fc9b0867 100644 --- a/libs/common/src/platform/services/fido2/fido2-authenticator.service.ts +++ b/libs/common/src/platform/services/fido2/fido2-authenticator.service.ts @@ -460,10 +460,8 @@ export class Fido2AuthenticatorService } private async findCredentialsByRp(rpId: string): Promise { - this.logService.debug("[findCredentialByRp]:", rpId) const activeUserId = await firstValueFrom(this.accountService.activeAccount$.pipe(getUserId)); const ciphers = await this.cipherService.getAllDecrypted(activeUserId); - this.logService.debug("[findCredentialsByRp] ciphers:", ciphers) return ciphers.filter( (cipher) => !cipher.isDeleted &&