1
0
mirror of https://github.com/bitwarden/browser synced 2026-02-07 04:03:29 +00:00

Remove old get_assertion implementation

This commit is contained in:
Isaiah Inuwa
2025-11-20 10:23:52 -06:00
parent 1621b28406
commit 6ce04191ed
3 changed files with 72 additions and 501 deletions

View File

@@ -1,20 +1,7 @@
use serde_json;
use std::{
alloc::{alloc, Layout},
ptr,
sync::Arc,
time::Duration,
};
use windows::core::{s, HRESULT};
use std::{sync::Arc, time::Duration};
use crate::util::{delay_load, wstr_to_string};
use crate::webauthn::WEBAUTHN_CREDENTIAL_LIST;
use crate::{
com_provider::{
parse_credential_list, WebAuthnPluginOperationRequest, WebAuthnPluginOperationResponse,
},
ipc2::PasskeyAssertionWithoutUserInterfaceRequest,
};
use crate::ipc2::PasskeyAssertionWithoutUserInterfaceRequest;
use crate::{
ipc2::{
PasskeyAssertionRequest, PasskeyAssertionResponse, Position, TimedCallback,
@@ -23,105 +10,69 @@ use crate::{
win_webauthn::{ErrorKind, HwndExt, PluginGetAssertionRequest, WinWebAuthnError},
};
// Windows API types for WebAuthn (from webauthn.h.sample)
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST {
pub dwVersion: u32,
pub pwszRpId: *const u16, // PCWSTR
pub cbRpId: u32,
pub pbRpId: *const u8,
pub cbClientDataHash: u32,
pub pbClientDataHash: *const u8,
pub CredentialList: WEBAUTHN_CREDENTIAL_LIST,
pub cbCborExtensionsMap: u32,
pub pbCborExtensionsMap: *const u8,
pub pAuthenticatorOptions: *const crate::webauthn::WebAuthnCtapCborAuthenticatorOptions,
// Add other fields as needed...
}
pub fn get_assertion(
ipc_client: &WindowsProviderClient,
request: PluginGetAssertionRequest,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
// Extract RP information
let rp_id = request.rp_id().to_string();
pub type PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST = *mut WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST;
// Extract client data hash
let client_data_hash = request.client_data_hash().to_vec();
// Windows API function signatures for decoding get assertion requests
type WebAuthNDecodeGetAssertionRequestFn = unsafe extern "stdcall" fn(
cbEncoded: u32,
pbEncoded: *const u8,
ppGetAssertionRequest: *mut PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST,
) -> HRESULT;
// Extract user verification requirement from authenticator options
let user_verification = match request.authenticator_options().user_verification() {
Some(true) => UserVerification::Required,
Some(false) => UserVerification::Discouraged,
None => UserVerification::Preferred,
};
type WebAuthNFreeDecodedGetAssertionRequestFn =
unsafe extern "stdcall" fn(pGetAssertionRequest: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST);
// Extract allowed credentials from credential list
let allowed_credential_ids: Vec<Vec<u8>> = request
.credential_list()
.iter()
.filter_map(|cred| cred.credential_id())
.map(|id| id.to_vec())
.collect();
// RAII wrapper for decoded get assertion request
pub struct DecodedGetAssertionRequest {
ptr: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST,
free_fn: Option<WebAuthNFreeDecodedGetAssertionRequestFn>,
}
let transaction_id = request.transaction_id.to_u128().to_le_bytes().to_vec();
let client_pos = request
.window_handle
.center_position()
.unwrap_or((640, 480));
impl DecodedGetAssertionRequest {
fn new(
ptr: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST,
free_fn: Option<WebAuthNFreeDecodedGetAssertionRequestFn>,
) -> Self {
Self { ptr, free_fn }
}
pub fn as_ref(&self) -> &WEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST {
unsafe { &*self.ptr }
}
}
impl Drop for DecodedGetAssertionRequest {
fn drop(&mut self) {
if !self.ptr.is_null() {
if let Some(free_fn) = self.free_fn {
tracing::debug!("Freeing decoded get assertion request");
unsafe {
free_fn(self.ptr);
}
}
}
}
}
// Function to decode get assertion request using Windows API
unsafe fn decode_get_assertion_request(
encoded_request: &[u8],
) -> Result<DecodedGetAssertionRequest, String> {
tracing::debug!("Attempting to decode get assertion request using Windows API");
// Load the Windows WebAuthn API function
let decode_fn: Option<WebAuthNDecodeGetAssertionRequestFn> =
delay_load(s!("webauthn.dll"), s!("WebAuthNDecodeGetAssertionRequest"));
let decode_fn =
decode_fn.ok_or("Failed to load WebAuthNDecodeGetAssertionRequest from webauthn.dll")?;
// Load the free function
let free_fn: Option<WebAuthNFreeDecodedGetAssertionRequestFn> = delay_load(
s!("webauthn.dll"),
s!("WebAuthNFreeDecodedGetAssertionRequest"),
tracing::debug!(
"Get assertion request - RP: {}, Allowed credentials: {:?}",
rp_id,
allowed_credential_ids
);
let mut pp_get_assertion_request: PWEBAUTHN_CTAPCBOR_GET_ASSERTION_REQUEST = ptr::null_mut();
// Send assertion request
let assertion_request = PasskeyAssertionRequest {
rp_id,
client_data_hash,
allowed_credentials: allowed_credential_ids,
user_verification,
window_xy: Position {
x: client_pos.0,
y: client_pos.1,
},
context: transaction_id,
};
let passkey_response = send_assertion_request(ipc_client, assertion_request)
.map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?;
tracing::debug!("Assertion response received: {:?}", passkey_response);
let result = decode_fn(
encoded_request.len() as u32,
encoded_request.as_ptr(),
&mut pp_get_assertion_request,
);
// Create proper WebAuthn response from passkey_response
tracing::debug!("Creating WebAuthn get assertion response");
if result.is_err() || pp_get_assertion_request.is_null() {
return Err(format!(
"WebAuthNDecodeGetAssertionRequest failed with HRESULT: {}",
result.0
));
}
Ok(DecodedGetAssertionRequest::new(
pp_get_assertion_request,
free_fn,
))
let response = create_get_assertion_response(
passkey_response.credential_id,
passkey_response.authenticator_data,
passkey_response.signature,
passkey_response.user_handle,
)?;
Ok(response)
}
/// Helper for assertion requests
@@ -237,241 +188,9 @@ fn create_get_assertion_response(
Ok(cbor_data)
}
unsafe fn write_response(
cbor_data: &[u8],
) -> Result<*mut WebAuthnPluginOperationResponse, HRESULT> {
// Allocate memory for the response data
let response_len = cbor_data.len();
let layout = Layout::from_size_align(response_len, 1).map_err(|_| HRESULT(-1))?;
let response_ptr = alloc(layout);
if response_ptr.is_null() {
return Err(HRESULT(-1));
}
// Copy response data
ptr::copy_nonoverlapping(cbor_data.as_ptr(), response_ptr, response_len);
// Allocate memory for the response structure
let response_layout = Layout::new::<WebAuthnPluginOperationResponse>();
let operation_response_ptr = alloc(response_layout) as *mut WebAuthnPluginOperationResponse;
if operation_response_ptr.is_null() {
return Err(HRESULT(-1));
}
// Initialize the response
ptr::write(
operation_response_ptr,
WebAuthnPluginOperationResponse {
encoded_response_byte_count: response_len as u32,
encoded_response_pointer: response_ptr,
},
);
Ok(operation_response_ptr)
}
pub fn get_assertion(
ipc_client: &WindowsProviderClient,
request: PluginGetAssertionRequest,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
// Extract RP information
let rp_id = request.rp_id().to_string();
// Extract client data hash
let client_data_hash = request.client_data_hash().to_vec();
// Extract user verification requirement from authenticator options
let user_verification = match request.authenticator_options().user_verification() {
Some(true) => UserVerification::Required,
Some(false) => UserVerification::Discouraged,
None => UserVerification::Preferred,
};
// Extract allowed credentials from credential list
let allowed_credential_ids: Vec<Vec<u8>> = request
.credential_list()
.iter()
.filter_map(|cred| cred.credential_id())
.map(|id| id.to_vec())
.collect();
let transaction_id = request.transaction_id.to_u128().to_le_bytes().to_vec();
let client_pos = request
.window_handle
.center_position()
.unwrap_or((640, 480));
tracing::debug!(
"Get assertion request - RP: {}, Allowed credentials: {:?}",
rp_id,
allowed_credential_ids
);
// Send assertion request
let assertion_request = PasskeyAssertionRequest {
rp_id,
client_data_hash,
allowed_credentials: allowed_credential_ids,
user_verification,
window_xy: Position {
x: client_pos.0,
y: client_pos.1,
},
context: transaction_id,
};
let passkey_response = send_assertion_request(ipc_client, assertion_request)
.map_err(|err| format!("Failed to get assertion response from IPC channel: {err}"))?;
tracing::debug!("Assertion response received: {:?}", passkey_response);
// Create proper WebAuthn response from passkey_response
tracing::debug!("Creating WebAuthn get assertion response");
let response = create_get_assertion_response(
passkey_response.credential_id,
passkey_response.authenticator_data,
passkey_response.signature,
passkey_response.user_handle,
)?;
Ok(response)
}
/// Implementation of PluginGetAssertion moved from com_provider.rs
pub unsafe fn plugin_get_assertion(
ipc_client: &WindowsProviderClient,
request: *const WebAuthnPluginOperationRequest,
response: *mut WebAuthnPluginOperationResponse,
) -> Result<(), HRESULT> {
tracing::debug!("PluginGetAssertion() called");
// Validate input parameters
if request.is_null() || response.is_null() {
tracing::debug!("Invalid parameters passed to PluginGetAssertion");
return Err(HRESULT(-1));
}
let req = &*request;
let transaction_id = format!("{:?}", req.transaction_id);
let coords = req.window_coordinates().unwrap_or((400, 400));
tracing::debug!("Get assertion request - Transaction: {}", transaction_id);
if req.encoded_request_byte_count == 0 || req.encoded_request_pointer.is_null() {
tracing::error!("No encoded request data provided");
return Err(HRESULT(-1));
}
let encoded_request_slice = std::slice::from_raw_parts(
req.encoded_request_pointer,
req.encoded_request_byte_count as usize,
);
// Try to decode the request using Windows API
let decoded_wrapper = decode_get_assertion_request(encoded_request_slice).map_err(|err| {
tracing::debug!("Failed to decode get assertion request: {err}");
HRESULT(-1)
})?;
let decoded_request = decoded_wrapper.as_ref();
tracing::debug!("Successfully decoded get assertion request using Windows API");
// Extract RP information
let rpid = if decoded_request.pwszRpId.is_null() {
tracing::error!("RP ID is null");
return Err(HRESULT(-1));
} else {
match wstr_to_string(decoded_request.pwszRpId) {
Ok(id) => id,
Err(e) => {
tracing::error!("Failed to decode RP ID: {}", e);
return Err(HRESULT(-1));
}
}
};
// Extract client data hash
let client_data_hash =
if decoded_request.cbClientDataHash == 0 || decoded_request.pbClientDataHash.is_null() {
tracing::error!("Client data hash is required for assertion");
return Err(HRESULT(-1));
} else {
let hash_slice = std::slice::from_raw_parts(
decoded_request.pbClientDataHash,
decoded_request.cbClientDataHash as usize,
);
hash_slice.to_vec()
};
// Extract user verification requirement from authenticator options
let user_verification = if !decoded_request.pAuthenticatorOptions.is_null() {
let auth_options = &*decoded_request.pAuthenticatorOptions;
match auth_options.user_verification {
1 => UserVerification::Required,
-1 => UserVerification::Discouraged,
0 | _ => UserVerification::Preferred, // Default or undefined
}
} else {
UserVerification::Preferred // Default or undefined
};
// Extract allowed credentials from credential list
let allowed_credentials = parse_credential_list(&decoded_request.CredentialList);
// Create Windows assertion request
let transaction_id = req.transaction_id.to_u128().to_le_bytes().to_vec();
let assertion_request = PasskeyAssertionRequest {
rp_id: rpid.clone(),
client_data_hash,
allowed_credentials: allowed_credentials.clone(),
user_verification,
window_xy: Position {
x: coords.0,
y: coords.1,
},
context: transaction_id,
};
tracing::debug!(
"Get assertion request - RP: {}, Allowed credentials: {:?}",
rpid,
allowed_credentials
);
// Send assertion request
let passkey_response =
send_assertion_request(ipc_client, assertion_request).map_err(|err| {
tracing::error!("Assertion request failed: {err}");
HRESULT(-1)
})?;
tracing::debug!("Assertion response received: {:?}", passkey_response);
// Create proper WebAuthn response from passkey_response
tracing::debug!("Creating WebAuthn get assertion response");
let webauthn_response = create_get_assertion_response(
passkey_response.credential_id,
passkey_response.authenticator_data,
passkey_response.signature,
passkey_response.user_handle,
)
.map_err(|err| {
tracing::error!("Failed to encode WebAuthn assertion response as CBOR: {err}");
HRESULT(-1)
})
.and_then(|cbor| write_response(&cbor))
.map_err(|err| {
tracing::error!("Failed to create WebAuthn assertion response: {err}");
HRESULT(-1)
})?;
tracing::debug!("Successfully created WebAuthn assertion response");
(*response).encoded_response_byte_count = (*webauthn_response).encoded_response_byte_count;
(*response).encoded_response_pointer = (*webauthn_response).encoded_response_pointer;
Ok(())
}
#[cfg(test)]
mod tests {
use std::ptr::slice_from_raw_parts;
use super::{create_get_assertion_response, write_response};
use super::create_get_assertion_response;
#[test]
fn test_create_native_assertion_response() {
@@ -479,21 +198,14 @@ mod tests {
let authenticator_data = vec![5, 6, 7, 8];
let signature = vec![9, 10, 11, 12];
let user_handle = vec![13, 14, 15, 16];
let slice = unsafe {
let cbor = create_get_assertion_response(
credential_id,
authenticator_data,
signature,
user_handle,
)
.unwrap();
let response = *write_response(&cbor).unwrap();
&*slice_from_raw_parts(
response.encoded_response_pointer,
response.encoded_response_byte_count as usize,
)
};
let cbor = create_get_assertion_response(
credential_id,
authenticator_data,
signature,
user_handle,
)
.unwrap();
// CTAP2_OK, Map(5 elements)
assert_eq!([0x00, 0xa5], slice[..2]);
assert_eq!([0x00, 0xa5], cbor[..2]);
}
}

View File

@@ -1,14 +1,6 @@
use std::sync::Arc;
use std::time::Duration;
use windows::core::{implement, interface, IInspectable, IUnknown, Interface, HRESULT};
use windows::Win32::Foundation::{RECT, S_OK};
use windows::Win32::System::Com::*;
use windows::Win32::Foundation::RECT;
use windows::Win32::UI::WindowsAndMessaging::GetWindowRect;
use crate::assert::plugin_get_assertion;
use crate::ipc2::{TimedCallback, WindowsProviderClient};
use crate::make_credential::plugin_make_credential;
use crate::webauthn::WEBAUTHN_CREDENTIAL_LIST;
/// Plugin request type enum as defined in the IDL
@@ -18,6 +10,7 @@ pub enum WebAuthnPluginRequestType {
CTAP2_CBOR = 0x01,
}
/*
/// Plugin lock status enum as defined in the IDL
#[repr(u32)]
#[derive(Debug, Copy, Clone)]
@@ -25,6 +18,7 @@ pub enum PluginLockStatus {
PluginLocked = 0,
PluginUnlocked = 1,
}
*/
/// Used when creating and asserting credentials.
/// Header File Name: _WEBAUTHN_PLUGIN_OPERATION_REQUEST
@@ -65,34 +59,6 @@ pub struct WebAuthnPluginOperationResponse {
pub encoded_response_pointer: *mut u8,
}
/// Used to cancel an operation.
/// Header File Name: _WEBAUTHN_PLUGIN_CANCEL_OPERATION_REQUEST
/// Header File Usage: CancelOperation()
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct WebAuthnPluginCancelOperationRequest {
pub transaction_id: windows::core::GUID,
pub request_signature_byte_count: u32,
pub request_signature_pointer: *mut u8,
}
// Stable IPluginAuthenticator interface
#[interface("d26bcf6f-b54c-43ff-9f06-d5bf148625f7")]
pub unsafe trait IPluginAuthenticator: windows::core::IUnknown {
fn MakeCredential(
&self,
request: *const WebAuthnPluginOperationRequest,
response: *mut WebAuthnPluginOperationResponse,
) -> HRESULT;
fn GetAssertion(
&self,
request: *const WebAuthnPluginOperationRequest,
response: *mut WebAuthnPluginOperationResponse,
) -> HRESULT;
fn CancelOperation(&self, request: *const WebAuthnPluginCancelOperationRequest) -> HRESULT;
fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT;
}
pub unsafe fn parse_credential_list(credential_list: &WEBAUTHN_CREDENTIAL_LIST) -> Vec<Vec<u8>> {
let mut allowed_credentials = Vec::new();
@@ -144,112 +110,3 @@ pub unsafe fn parse_credential_list(credential_list: &WEBAUTHN_CREDENTIAL_LIST)
);
allowed_credentials
}
#[implement(IPluginAuthenticator)]
pub struct PluginAuthenticatorComObject {
client: WindowsProviderClient,
}
#[implement(IClassFactory)]
pub struct Factory;
impl IPluginAuthenticator_Impl for PluginAuthenticatorComObject_Impl {
unsafe fn MakeCredential(
&self,
request: *const WebAuthnPluginOperationRequest,
response: *mut WebAuthnPluginOperationResponse,
) -> HRESULT {
tracing::debug!("MakeCredential() called");
tracing::debug!("version2");
// Convert to legacy format for internal processing
if request.is_null() || response.is_null() {
tracing::debug!("MakeCredential: Invalid request or response pointers passed");
return HRESULT(-1);
}
let response = match plugin_make_credential(&self.client, request, response) {
Ok(()) => S_OK,
Err(err) => err,
};
tracing::debug!("MakeCredential() completed");
response
}
unsafe fn GetAssertion(
&self,
request: *const WebAuthnPluginOperationRequest,
response: *mut WebAuthnPluginOperationResponse,
) -> HRESULT {
tracing::debug!("GetAssertion() called");
if request.is_null() || response.is_null() {
return HRESULT(-1);
}
match plugin_get_assertion(&self.client, request, response) {
Ok(()) => S_OK,
Err(err) => err,
}
}
unsafe fn CancelOperation(
&self,
_request: *const WebAuthnPluginCancelOperationRequest,
) -> HRESULT {
tracing::debug!("CancelOperation() called");
HRESULT(0)
}
unsafe fn GetLockStatus(&self, lock_status: *mut PluginLockStatus) -> HRESULT {
tracing::debug!(
"GetLockStatus() called <PID {}, Thread {:?}>",
std::process::id(),
std::thread::current().id()
);
if lock_status.is_null() {
return HRESULT(-2147024809); // E_INVALIDARG
}
*lock_status = PluginLockStatus::PluginLocked;
let callback = Arc::new(TimedCallback::new());
self.client.get_lock_status(callback.clone());
match callback.wait_for_response(Duration::from_secs(3)) {
Ok(Ok(response)) => {
let status = if response.is_unlocked {
PluginLockStatus::PluginUnlocked
} else {
PluginLockStatus::PluginLocked
};
tracing::debug!("GetLockStatus() received {status:?}");
*lock_status = status;
HRESULT(0)
}
Ok(Err(err)) => {
tracing::error!("GetLockStatus() call failed: {err}");
HRESULT(-1)
}
Err(_) => {
tracing::error!("GetLockStatus() call timed out");
HRESULT(-1)
}
}
}
}
impl IClassFactory_Impl for Factory_Impl {
fn CreateInstance(
&self,
_outer: windows::core::Ref<IUnknown>,
iid: *const windows::core::GUID,
object: *mut *mut core::ffi::c_void,
) -> windows::core::Result<()> {
tracing::debug!("Creating COM server instance.");
tracing::debug!("Trying to connect to Bitwarden IPC");
let client = WindowsProviderClient::connect();
tracing::debug!("Connected to Bitwarden IPC");
let unknown: IInspectable = PluginAuthenticatorComObject { client }.into(); // TODO: IUnknown ?
unsafe { unknown.query(iid, object).ok() }
}
fn LockServer(&self, _lock: windows::core::BOOL) -> windows::core::Result<()> {
Ok(())
}
}

View File

@@ -1,6 +1,6 @@
//! Types and functions defined in the Windows WebAuthn API.
use std::{collections::HashSet, ptr::NonNull};
use std::{collections::HashSet, mem::MaybeUninit, ptr::NonNull};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use ciborium::Value;
@@ -10,7 +10,9 @@ use windows::{
};
use windows_core::{s, PCWSTR};
use crate::win_webauthn::{util::WindowsString, Clsid, ErrorKind, WinWebAuthnError};
use crate::win_webauthn::{
com::ComBuffer, util::WindowsString, Clsid, ErrorKind, WinWebAuthnError,
};
macro_rules! webauthn_call {
($symbol:literal as fn $fn_name:ident($($arg:ident: $arg_type:ty),+) -> $result_type:ty) => (