1
0
mirror of https://github.com/bitwarden/browser synced 2026-02-13 15:03:26 +00:00

Implement Windows plugin WebAuthn cancellation

This commit is contained in:
Isaiah Inuwa
2025-11-24 21:37:23 -06:00
parent 926168e97e
commit 46f990c340
8 changed files with 258 additions and 43 deletions

View File

@@ -892,12 +892,12 @@ pub struct PluginCancelOperationRequest {
impl PluginCancelOperationRequest {
/// Request transaction ID
fn transaction_id(&self) -> GUID {
pub fn transaction_id(&self) -> GUID {
self.as_ref().transactionId
}
/// Request signature.
fn request_signature(&self) -> &[u8] {
pub fn request_signature(&self) -> &[u8] {
unsafe {
std::slice::from_raw_parts(
self.as_ref().pbRequestSignature,

View File

@@ -1,11 +1,14 @@
use serde_json;
use std::{sync::Arc, time::Duration};
use std::{
sync::{mpsc::Receiver, Arc},
time::Duration,
};
use win_webauthn::plugin::PluginGetAssertionRequest;
use crate::{
ipc2::{
PasskeyAssertionRequest, PasskeyAssertionResponse,
CallbackError, PasskeyAssertionRequest, PasskeyAssertionResponse,
PasskeyAssertionWithoutUserInterfaceRequest, Position, TimedCallback, UserVerification,
WindowsProviderClient,
},
@@ -15,6 +18,7 @@ use crate::{
pub fn get_assertion(
ipc_client: &WindowsProviderClient,
request: PluginGetAssertionRequest,
cancellation_token: Receiver<()>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
// Extract RP information
let rp_id = request.rp_id().to_string();
@@ -64,8 +68,9 @@ pub fn get_assertion(
},
context: transaction_id,
};
let passkey_response = send_assertion_request(ipc_client, assertion_request)
.map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?;
let passkey_response =
send_assertion_request(ipc_client, assertion_request, cancellation_token)
.map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?;
tracing::debug!("Assertion response received: {:?}", passkey_response);
// Create proper WebAuthn response from passkey_response
@@ -84,6 +89,7 @@ pub fn get_assertion(
fn send_assertion_request(
ipc_client: &WindowsProviderClient,
request: PasskeyAssertionRequest,
cancellation_token: Receiver<()>,
) -> Result<PasskeyAssertionResponse, String> {
tracing::debug!(
"Assertion request data - RP ID: {}, Client data hash: {} bytes, Allowed credentials: {:?}",
@@ -113,9 +119,13 @@ fn send_assertion_request(
} else {
ipc_client.prepare_passkey_assertion(request, callback.clone());
}
let wait_time = Duration::from_secs(600);
callback
.wait_for_response(Duration::from_secs(30))
.map_err(|_| "Registration request timed out".to_string())?
.wait_for_response(wait_time, Some(cancellation_token))
.map_err(|err| match err {
CallbackError::Timeout => "Assertion request timed out".to_string(),
CallbackError::Cancelled => "Assertion request cancelled".to_string(),
})?
.map_err(|err| err.to_string())
}

View File

@@ -4,7 +4,7 @@ use std::{
fmt::Display,
sync::{
atomic::AtomicU32,
mpsc::{self, Receiver, Sender},
mpsc::{self, Receiver, RecvError, RecvTimeoutError, Sender},
Arc, Mutex,
},
time::{Duration, Instant},
@@ -285,25 +285,60 @@ impl WindowsProviderClient {
}
}
pub struct TimedCallback<T> {
tx: Mutex<Option<Sender<Result<T, BitwardenError>>>>,
rx: Mutex<Receiver<Result<T, BitwardenError>>>,
pub enum CallbackError {
Timeout,
Cancelled,
}
impl<T> TimedCallback<T> {
pub struct TimedCallback<T> {
tx: Arc<Mutex<Option<Sender<Result<T, BitwardenError>>>>>,
rx: Arc<Mutex<Receiver<Result<T, BitwardenError>>>>,
}
impl<T: Send + 'static> TimedCallback<T> {
pub fn new() -> Self {
let (tx, rx) = mpsc::channel();
Self {
tx: Mutex::new(Some(tx)),
rx: Mutex::new(rx),
tx: Arc::new(Mutex::new(Some(tx))),
rx: Arc::new(Mutex::new(rx)),
}
}
pub fn wait_for_response(
&self,
timeout: Duration,
) -> Result<Result<T, BitwardenError>, mpsc::RecvTimeoutError> {
self.rx.lock().unwrap().recv_timeout(timeout)
cancellation_token: Option<Receiver<()>>,
) -> Result<Result<T, BitwardenError>, CallbackError> {
let (tx, rx) = mpsc::channel();
if let Some(cancellation_token) = cancellation_token {
let tx2 = tx.clone();
let cancellation_token = Mutex::new(cancellation_token);
std::thread::spawn(move || {
if let Ok(()) = cancellation_token.lock().unwrap().recv_timeout(timeout) {
tracing::debug!("Forwarding cancellation");
_ = tx2.send(Err(CallbackError::Cancelled));
}
});
}
let response_rx = self.rx.clone();
std::thread::spawn(move || {
if let Ok(response) = response_rx.lock().unwrap().recv_timeout(timeout) {
_ = tx.send(Ok(response));
}
});
match rx.recv_timeout(timeout) {
Ok(Ok(response)) => Ok(response),
Ok(err @ Err(CallbackError::Cancelled)) => {
tracing::debug!("Received cancellation, dropping.");
err
}
Ok(err @ Err(CallbackError::Timeout)) => {
tracing::debug!("Request timed out, dropping.");
err
}
Err(RecvTimeoutError::Timeout) => Err(CallbackError::Timeout),
Err(_) => Err(CallbackError::Cancelled),
}
}
fn send(&self, response: Result<T, BitwardenError>) {

View File

@@ -9,11 +9,16 @@ mod make_credential;
mod types;
mod util;
use std::{collections::HashSet, sync::Arc, time::Duration};
// Re-export main functionality
pub use types::UserVerificationRequirement;
use std::{
collections::{HashMap, HashSet},
sync::{
mpsc::{self, Sender},
Arc, Mutex,
},
time::Duration,
};
use base64::engine::{general_purpose::STANDARD, Engine as _};
use win_webauthn::{
plugin::{
PluginAddAuthenticatorOptions, PluginAuthenticator, PluginCancelOperationRequest,
@@ -21,8 +26,12 @@ use win_webauthn::{
},
AuthenticatorInfo, CtapVersion, PublicKeyCredentialParameters,
};
use windows_core::GUID;
use crate::ipc2::{TimedCallback, WindowsProviderClient};
use crate::ipc2::{ConnectionStatus, TimedCallback, WindowsProviderClient};
// Re-export main functionality
pub use types::UserVerificationRequirement;
const AUTHENTICATOR_NAME: &str = "Bitwarden Desktop";
const RPID: &str = "bitwarden.com";
@@ -45,7 +54,10 @@ pub fn register() -> std::result::Result<(), String> {
let clsid = CLSID.try_into().expect("valid GUID string");
let plugin = WebAuthnPlugin::new(clsid);
let r = plugin.register_server(BitwardenPluginAuthenticator);
let r = plugin.register_server(BitwardenPluginAuthenticator {
client: Mutex::new(None),
callbacks: Arc::new(Mutex::new(HashMap::new())),
});
tracing::debug!("Registered the com library: {:?}", r);
tracing::debug!("Parsing authenticator options");
@@ -83,14 +95,28 @@ pub fn register() -> std::result::Result<(), String> {
Ok(())
}
struct BitwardenPluginAuthenticator;
struct BitwardenPluginAuthenticator {
/// Client to communicate with desktop app over IPC.
client: Mutex<Option<Arc<WindowsProviderClient>>>,
/// Map of transaction IDs to cancellation tokens
callbacks: Arc<Mutex<HashMap<GUID, Sender<()>>>>,
}
impl BitwardenPluginAuthenticator {
fn get_client(&self) -> WindowsProviderClient {
fn get_client(&self) -> Arc<WindowsProviderClient> {
tracing::debug!("Connecting to client via IPC");
let client = WindowsProviderClient::connect();
tracing::debug!("Connected to client via IPC successfully");
client
let mut client = self.client.lock().unwrap();
match client.as_ref().map(|c| (c, c.get_connection_status())) {
Some((_, ConnectionStatus::Disconnected)) | None => {
tracing::debug!("Connecting to desktop app");
let c = WindowsProviderClient::connect();
tracing::debug!("Connected to client via IPC successfully");
_ = client.insert(Arc::new(c));
}
_ => {}
};
client.as_ref().unwrap().clone()
}
}
@@ -101,7 +127,18 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator {
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
tracing::debug!("Received MakeCredential: {request:?}");
let client = self.get_client();
make_credential::make_credential(&client, request)
let (cancel_tx, cancel_rx) = mpsc::channel();
let transaction_id = request.transaction_id;
self.callbacks
.lock()
.expect("not poisoned")
.insert(transaction_id, cancel_tx);
let response = make_credential::make_credential(&client, request, cancel_rx);
self.callbacks
.lock()
.expect("not poisoned")
.remove(&transaction_id);
response
}
fn get_assertion(
@@ -110,13 +147,39 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator {
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
tracing::debug!("Received GetAssertion: {request:?}");
let client = self.get_client();
assert::get_assertion(&client, request)
let (cancel_tx, cancel_rx) = mpsc::channel();
let transaction_id = request.transaction_id;
self.callbacks
.lock()
.expect("not poisoned")
.insert(transaction_id, cancel_tx);
let response = assert::get_assertion(&client, request, cancel_rx);
self.callbacks
.lock()
.expect("not poisoned")
.remove(&transaction_id);
response
}
fn cancel_operation(
&self,
request: PluginCancelOperationRequest,
) -> Result<(), Box<dyn std::error::Error>> {
let transaction_id = request.transaction_id();
tracing::debug!(?transaction_id, "Received CancelOperation");
if let Some(cancellation_token) = self
.callbacks
.lock()
.expect("not poisoned")
.get(&request.transaction_id())
{
_ = cancellation_token.send(());
let client = self.get_client();
let context = STANDARD.encode(transaction_id.to_u128().to_le_bytes().to_vec());
tracing::debug!("Sending cancel operation for context: {context}");
client.send_native_status("cancel-operation".to_string(), context);
}
Ok(())
}
@@ -124,7 +187,7 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator {
let callback = Arc::new(TimedCallback::new());
let client = self.get_client();
client.get_lock_status(callback.clone());
match callback.wait_for_response(Duration::from_secs(3)) {
match callback.wait_for_response(Duration::from_secs(3), None) {
Ok(Ok(response)) => {
if response.is_unlocked {
Ok(PluginLockStatus::PluginUnlocked)

View File

@@ -1,6 +1,7 @@
use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc::Receiver, Arc};
use std::time::Duration;
use win_webauthn::{
@@ -8,6 +9,7 @@ use win_webauthn::{
CtapTransport,
};
use crate::ipc2::CallbackError;
use crate::{
ipc2::{
PasskeyRegistrationRequest, PasskeyRegistrationResponse, Position, TimedCallback,
@@ -19,6 +21,7 @@ use crate::{
pub fn make_credential(
ipc_client: &WindowsProviderClient,
request: PluginMakeCredentialRequest,
cancellation_token: Receiver<()>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
tracing::debug!("=== PluginMakeCredential() called ===");
@@ -113,9 +116,14 @@ pub fn make_credential(
registration_request.user_name
);
if let Ok(()) = cancellation_token.try_recv() {
return Err(format!("Request {:?} cancelled", request.transaction_id))?;
}
// Send registration request
let passkey_response = send_registration_request(ipc_client, registration_request)
.map_err(|err| format!("Registration request failed: {err}"))?;
let passkey_response =
send_registration_request(ipc_client, registration_request, cancellation_token)
.map_err(|err| format!("Registration request failed: {err}"))?;
tracing::debug!("Registration response received: {:?}", passkey_response);
// Create proper WebAuthn response from passkey_response
@@ -130,6 +138,7 @@ pub fn make_credential(
fn send_registration_request(
ipc_client: &WindowsProviderClient,
request: PasskeyRegistrationRequest,
cancellation_token: Receiver<()>,
) -> Result<PasskeyRegistrationResponse, String> {
tracing::debug!("Registration request data - RP ID: {}, User ID: {} bytes, User name: {}, Client data hash: {} bytes, Algorithms: {:?}, Excluded credentials: {}",
request.rp_id, request.user_handle.len(), request.user_name, request.client_data_hash.len(), request.supported_algorithms, request.excluded_credentials.len());
@@ -139,9 +148,15 @@ fn send_registration_request(
tracing::debug!("Sending registration request: {}", request_json);
let callback = Arc::new(TimedCallback::new());
ipc_client.prepare_passkey_registration(request, callback.clone());
// Corresponds to maximum recommended timeout for WebAuthn.
// https://www.w3.org/TR/webauthn-3/#recommended-range-and-default-for-a-webauthn-ceremony-timeout
let wait_time = Duration::from_secs(600);
let response = callback
.wait_for_response(Duration::from_secs(30))
.map_err(|_| "Registration request timed out".to_string())?
.wait_for_response(wait_time, Some(cancellation_token))
.map_err(|err| match err {
CallbackError::Timeout => "Registration request timed out".to_string(),
CallbackError::Cancelled => "Registration request cancelled".to_string(),
})?
.map_err(|err| err.to_string());
if response.is_ok() {
tracing::debug!("Requesting credential sync after registering a new credential.");

View File

@@ -54,6 +54,7 @@ const NativeCredentialSyncFeatureFlag = ipc.platform.deviceType === DeviceType.W
export class DesktopAutofillService implements OnDestroy {
private destroy$ = new Subject<void>();
private registrationRequest: autofill.PasskeyRegistrationRequest;
private inFlightRequests: Record<string, AbortController> = {};
constructor(
private logService: LogService,
@@ -210,6 +211,12 @@ export class DesktopAutofillService implements OnDestroy {
this.logService.debug("listenPasskeyRegistration2", this.convertRegistrationRequest(request));
const controller = new AbortController();
let requestId = request.context ? this.contextToRequestId(request.context) : null;
this.logService.debug("Request context:", requestId)
if (requestId) {
this.inFlightRequests[requestId] = controller;
}
const ctx = request.context ? new Uint8Array(request.context).buffer : null;
try {
@@ -226,6 +233,11 @@ export class DesktopAutofillService implements OnDestroy {
this.logService.error("listenPasskeyRegistration error", error);
callback(error, null);
}
finally {
if (requestId) {
delete this.inFlightRequests[requestId];
}
}
this.logService.info("Passkey registration completed.")
});
@@ -247,6 +259,10 @@ export class DesktopAutofillService implements OnDestroy {
);
const controller = new AbortController();
let requestId = request.context ? this.contextToRequestId(request.context) : null;
if (requestId) {
this.inFlightRequests[requestId] = controller;
}
try {
// For some reason the credentialId is passed as an empty array in the request, so we need to
@@ -296,6 +312,11 @@ export class DesktopAutofillService implements OnDestroy {
callback(error, null);
return;
}
finally {
if (requestId) {
delete this.inFlightRequests[requestId];
}
}
},
);
@@ -311,6 +332,11 @@ export class DesktopAutofillService implements OnDestroy {
this.logService.debug("listenPasskeyAssertion", clientId, sequenceNumber, request);
const ctx = request.context ? new Uint8Array(request.context).buffer : null;
const controller = new AbortController();
let requestId = request.context ? this.contextToRequestId(request.context) : null;
if (requestId) {
this.inFlightRequests[requestId] = controller;
}
try {
const response = await this.fido2AuthenticatorService.getAssertion(
this.convertAssertionRequest(request),
@@ -324,6 +350,11 @@ export class DesktopAutofillService implements OnDestroy {
this.logService.error("listenPasskeyAssertion error", error);
callback(error, null);
}
finally {
if (requestId) {
delete this.inFlightRequests[requestId];
}
}
});
// Listen for native status messages
@@ -340,6 +371,18 @@ export class DesktopAutofillService implements OnDestroy {
// perform ad-hoc sync
await this.adHocSync();
}
if (status.key === "cancel-operation" && status.value) {
const requestId = status.value
const controller = this.inFlightRequests[requestId]
if (controller) {
this.logService.debug(`Cancelling request ${requestId}`);
controller.abort("Operation cancelled")
}
else {
this.logService.debug(`Unknown request: ${requestId}`);
}
}
});
ipc.autofill.listenLockStatusQuery(async (clientId, sequenceNumber, request, callback) => {
@@ -457,6 +500,12 @@ export class DesktopAutofillService implements OnDestroy {
};
}
private contextToRequestId(context: number[]): string {
const buf = new Uint8Array(context).buffer;
const requestId = Utils.fromBufferToB64(buf);
return requestId
}
ngOnDestroy(): void {
this.destroy$.next();
this.destroy$.complete();

View File

@@ -78,6 +78,7 @@ export class DesktopFido2UserInterfaceService
this.router,
this.desktopSettingsService,
nativeWindowObject,
abortController,
transactionContext,
);
@@ -95,6 +96,7 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi
private router: Router,
private desktopSettingsService: DesktopSettingsService,
private windowObject: NativeWindowObject,
private abortController: AbortController,
private transactionContext: ArrayBuffer,
) {}
@@ -177,15 +179,22 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi
private async waitForUiChosenCipher(
timeoutMs: number = 60000,
): Promise<{ cipherId?: string; userVerified: boolean } | undefined> {
const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation();
try {
return await lastValueFrom(this.chosenCipherSubject.pipe(timeout(timeoutMs)));
} catch {
this.abortController.signal.throwIfAborted();
const confirmPromise = lastValueFrom(this.chosenCipherSubject.pipe(timeout(timeoutMs)));
return await Promise.race([confirmPromise, cancelPromise]);
} catch (e) {
// If we hit a timeout, return undefined instead of throwing
this.logService.debug("Timed out or cancelled?", e);
this.logService.warning("Timeout: User did not select a cipher within the allowed time", {
timeoutMs,
});
return { cipherId: undefined, userVerified: false };
}
finally {
this.unsusbscribeCancellation(abortFn);
}
}
/**
@@ -204,7 +213,19 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi
* @returns
*/
private async waitForUiNewCredentialConfirmation(): Promise<boolean> {
return lastValueFrom(this.confirmCredentialSubject);
const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation();
try {
this.abortController.signal.throwIfAborted();
const confirmPromise = lastValueFrom(this.confirmCredentialSubject);
return await Promise.race([confirmPromise, cancelPromise]);
} catch (e) {
// If we hit a timeout, return undefined instead of throwing
this.logService.debug("Timed out or cancelled?", e);
return undefined;
}
finally {
this.unsusbscribeCancellation(abortFn);
}
}
/**
@@ -375,17 +396,21 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi
await this.showUi("/lock", this.windowObject.windowXy, true, true);
let status2: AuthenticationStatus;
const { promise: cancelPromise, listener: abortFn } = this.subscribeToCancellation();
try {
status2 = await lastValueFrom(
status2 = await Promise.race([lastValueFrom(
this.authService.activeAccountStatus$.pipe(
filter((s) => s === AuthenticationStatus.Unlocked),
take(1),
timeout(1000 * 60 * 5), // 5 minutes
),
);
), cancelPromise]);
} catch (error) {
this.logService.warning("Error while waiting for vault to unlock", error);
}
finally {
this.unsusbscribeCancellation(abortFn);
}
if (status2 === AuthenticationStatus.Unlocked) {
await this.router.navigate(["/"]);
@@ -405,4 +430,24 @@ export class DesktopFido2UserInterfaceSession implements Fido2UserInterfaceSessi
async close() {
this.logService.debug("close");
}
subscribeToCancellation() {
let cancelReject: (reason?: any) => void;
const cancelPromise: Promise<never> = new Promise((_, reject) => {
cancelReject = reject
});
const abortFn = (ev: Event) => {
if (ev.target instanceof AbortSignal) {
cancelReject(ev.target.reason)
}
};
this.abortController.signal.addEventListener("abort", abortFn, { once: true });
return { promise: cancelPromise, listener: abortFn };
}
unsusbscribeCancellation(listener: (ev: Event) => void): void {
this.abortController.signal.removeEventListener("abort", listener);
}
}

View File

@@ -460,10 +460,8 @@ export class Fido2AuthenticatorService<ParentWindowReference>
}
private async findCredentialsByRp(rpId: string): Promise<CipherView[]> {
this.logService.debug("[findCredentialByRp]:", rpId)
const activeUserId = await firstValueFrom(this.accountService.activeAccount$.pipe(getUserId));
const ciphers = await this.cipherService.getAllDecrypted(activeUserId);
this.logService.debug("[findCredentialsByRp] ciphers:", ciphers)
return ciphers.filter(
(cipher) =>
!cipher.isDeleted &&