diff --git a/apps/desktop/desktop_native/napi/index.d.ts b/apps/desktop/desktop_native/napi/index.d.ts index 531d77777f5..cde03c95005 100644 --- a/apps/desktop/desktop_native/napi/index.d.ts +++ b/apps/desktop/desktop_native/napi/index.d.ts @@ -158,6 +158,12 @@ export declare namespace autofill { export interface LockStatusQueryResponse { isUnlocked: boolean } + export interface WindowHandleQueryRequest { + windowHandle: string + } + export interface WindowHandleQueryResponse { + handle: string + } export interface Position { x: number y: number @@ -221,7 +227,7 @@ export declare namespace autofill { * connection and must be the same for both the server and client. @param callback * This function will be called whenever a message is received from a client. */ - static listen(name: string, registrationCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyRegistrationRequest) => void, assertionCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyAssertionRequest) => void, assertionWithoutUserInterfaceCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyAssertionWithoutUserInterfaceRequest) => void, nativeStatusCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: NativeStatus) => void, lockStatusQueryCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: LockStatusQueryRequest) => void): Promise + static listen(name: string, registrationCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyRegistrationRequest) => void, assertionCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyAssertionRequest) => void, assertionWithoutUserInterfaceCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: PasskeyAssertionWithoutUserInterfaceRequest) => void, nativeStatusCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: NativeStatus) => void, lockStatusQueryCallback: (error: null | Error, clientId: number, sequenceNumber: number, message: LockStatusQueryRequest) => void, windowHandleQueryCallback: (err: Error | null, arg0: number, arg1: number, arg2: WindowHandleQueryRequest) => any): Promise /** Return the path to the IPC server. */ getPath(): string /** Stop the IPC server. */ @@ -229,6 +235,7 @@ export declare namespace autofill { completeRegistration(clientId: number, sequenceNumber: number, response: PasskeyRegistrationResponse): number completeAssertion(clientId: number, sequenceNumber: number, response: PasskeyAssertionResponse): number completeLockStatusQuery(clientId: number, sequenceNumber: number, response: LockStatusQueryResponse): number + completeWindowHandleQuery(clientId: number, sequenceNumber: number, response: WindowHandleQueryResponse): number completeError(clientId: number, sequenceNumber: number, error: string): number } } diff --git a/apps/desktop/desktop_native/napi/src/lib.rs b/apps/desktop/desktop_native/napi/src/lib.rs index 0307cb3eb1d..fea199a7806 100644 --- a/apps/desktop/desktop_native/napi/src/lib.rs +++ b/apps/desktop/desktop_native/napi/src/lib.rs @@ -672,6 +672,20 @@ pub mod autofill { pub is_unlocked: bool, } + #[napi(object)] + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct WindowHandleQueryRequest { + pub window_handle: String, + } + + #[napi(object)] + #[derive(Debug, Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct WindowHandleQueryResponse { + pub handle: String, + } + #[derive(Serialize, Deserialize)] #[serde(bound = "T: Serialize + DeserializeOwned")] pub struct PasskeyMessage { @@ -818,6 +832,10 @@ pub mod autofill { (u32, u32, LockStatusQueryRequest), ErrorStrategy::CalleeHandled, >, + window_handle_query_callback: ThreadsafeFunction< + (u32, u32, WindowHandleQueryRequest), + ErrorStrategy::CalleeHandled, + >, ) -> napi::Result { let (send, mut recv) = tokio::sync::mpsc::channel::(32); tokio::spawn(async move { @@ -836,6 +854,24 @@ pub mod autofill { continue; }; + match serde_json::from_str::>( + &message, + ) { + Ok(msg) => { + let value = msg + .value + .map(|value| (client_id, msg.sequence_number, value)) + .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); + + window_handle_query_callback + .call(value, ThreadsafeFunctionCallMode::NonBlocking); + continue; + } + Err(e) => { + tracing::warn!(error = %e, "Could not deserialize request as WindowHandleQueryRequest. Trying other types..."); + } + } + match serde_json::from_str::>( &message, ) { @@ -994,6 +1030,20 @@ pub mod autofill { self.send(client_id, serde_json::to_string(&message).unwrap()) } + #[napi] + pub fn complete_window_handle_query( + &self, + client_id: u32, + sequence_number: u32, + response: WindowHandleQueryResponse, + ) -> napi::Result { + let message = PasskeyMessage { + sequence_number, + value: Ok(response), + }; + self.send(client_id, serde_json::to_string(&message).unwrap()) + } + #[napi] pub fn complete_error( &self, diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs index a95e466f848..713166b6250 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/mod.rs @@ -17,8 +17,12 @@ use tracing::{error, info}; mod assertion; mod lock_status; mod registration; +mod window_handle_query; -use crate::ipc2::lock_status::{GetLockStatusCallback, LockStatusRequest}; +use crate::ipc2::{ + lock_status::{GetLockStatusCallback, LockStatusRequest}, + window_handle_query::{GetWindowHandleQueryCallback, WindowHandleQueryRequest}, +}; pub use assertion::{ PasskeyAssertionRequest, PasskeyAssertionResponse, PasskeyAssertionWithoutUserInterfaceRequest, PreparePasskeyAssertionCallback, @@ -26,6 +30,7 @@ pub use assertion::{ pub use registration::{ PasskeyRegistrationRequest, PasskeyRegistrationResponse, PreparePasskeyRegistrationCallback, }; +pub use window_handle_query::WindowHandleQueryResponse; #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -209,6 +214,13 @@ impl WindowsProviderClient { self.send_message(LockStatusRequest {}, Some(Box::new(callback))); } + pub fn get_window_handle(&self, callback: Arc) { + self.send_message( + WindowHandleQueryRequest::default(), + Some(Box::new(callback)), + ); + } + pub fn get_connection_status(&self) -> ConnectionStatus { let is_connected = self .connection_status diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/window_handle_query.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/window_handle_query.rs new file mode 100644 index 00000000000..8e0ed0663ea --- /dev/null +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/ipc2/window_handle_query.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::ipc2::{BitwardenError, Callback, TimedCallback}; + +#[derive(Debug, Default, Serialize, Deserialize)] +pub(super) struct WindowHandleQueryRequest { + #[serde(rename = "windowHandle")] + window_handle: String, +} + +#[derive(Debug, Deserialize)] +pub struct WindowHandleQueryResponse { + #[serde(deserialize_with = "crate::util::deserialize_b64")] + pub(crate) handle: Vec, +} + +impl Callback for Arc { + fn complete(&self, response: serde_json::Value) -> Result<(), serde_json::Error> { + let response = serde_json::from_value(response)?; + self.as_ref().on_complete(response); + Ok(()) + } + + fn error(&self, error: BitwardenError) { + self.as_ref().on_error(error); + } +} + +pub trait GetWindowHandleQueryCallback: Send + Sync { + fn on_complete(&self, response: WindowHandleQueryResponse); + fn on_error(&self, error: BitwardenError); +} + +impl GetWindowHandleQueryCallback for TimedCallback { + fn on_complete(&self, response: WindowHandleQueryResponse) { + self.send(Ok(response)); + } + + fn on_error(&self, error: BitwardenError) { + self.send(Err(error)) + } +} diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs index 9fc2c59b34f..4e13df17a00 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/lib.rs @@ -11,6 +11,7 @@ mod util; use std::{ collections::{HashMap, HashSet}, + mem::MaybeUninit, sync::{ mpsc::{self, Sender}, Arc, Mutex, @@ -26,6 +27,13 @@ use win_webauthn::{ }, AuthenticatorInfo, CtapVersion, PublicKeyCredentialParameters, }; +use windows::Win32::{ + Foundation::HWND, + System::Threading::{AttachThreadInput, GetCurrentThreadId}, + UI::WindowsAndMessaging::{ + AllowSetForegroundWindow, BringWindowToTop, GetForegroundWindow, GetWindowThreadProcessId, + }, +}; use windows_core::GUID; use crate::ipc2::{ConnectionStatus, TimedCallback, WindowsProviderClient}; @@ -127,17 +135,65 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { ) -> Result, Box> { tracing::debug!("Received MakeCredential: {request:?}"); let client = self.get_client(); + + let plugin_hwnd = get_window_handle(&client)?; + unsafe { + // tracing::debug!( + // "Setting window {plugin_hwnd:?} as child of {:?}", + // client_hwnd + // ); + // if let Err(err) = SetParent(plugin_hwnd, Some(client_hwnd)) { + // tracing::warn!( + // "Failed to set {plugin_hwnd:?} as child of {:?}: {err}", + // request.window_handle + // ) + // }; + + let dw_current_thread = GetCurrentThreadId(); + let dw_fg_thread = GetWindowThreadProcessId(GetForegroundWindow(), None); + let result = AttachThreadInput(dw_current_thread, dw_fg_thread, true); + tracing::debug!("AttachThreadInput() - attach? {result:?}"); + // let result = SetForegroundWindow(plugin_hwnd); + // tracing::debug!("SetForegroundWindow? {result:?}"); + // let result = SetFocus(Some(plugin_hwnd)); + // tracing::debug!("SetFocus? {result:?}"); + // let result = SetActiveWindow(plugin_hwnd); + // tracing::debug!("Set active window? {result:?}"); + // let result = EnableWindow(plugin_hwnd, true); + // tracing::debug!("EnableWindow? {result:?}"); + let result = BringWindowToTop(plugin_hwnd); + tracing::debug!("BringWindowToTop? {result:?}"); + + // let result = SwitchToThisWindow(plugin_hwnd, true); + // tracing::debug!("SwitchToThisWindow? {result:?}"); + let result = AttachThreadInput(dw_current_thread, dw_fg_thread, false); + tracing::debug!("AttachThreadInput() - detach? {result:?}"); + }; let (cancel_tx, cancel_rx) = mpsc::channel(); let transaction_id = request.transaction_id; self.callbacks .lock() .expect("not poisoned") .insert(transaction_id, cancel_tx); + let client_hwnd = request.window_handle; let response = make_credential::make_credential(&client, request, cancel_rx); self.callbacks .lock() .expect("not poisoned") .remove(&transaction_id); + unsafe { + /* + _ = SetParent(plugin_hwnd, None) + .inspect_err(|err| tracing::debug!("Failed to reset parent: {err}")); + */ + let mut client_pid = MaybeUninit::uninit(); + if GetWindowThreadProcessId(client_hwnd, Some(client_pid.as_mut_ptr())) != 0 { + let client_pid = client_pid.assume_init(); + if let Err(err) = AllowSetForegroundWindow(client_pid) { + tracing::debug!("Failed to allow client to set foreground window: {err}") + }; + } + }; response } @@ -147,6 +203,18 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { ) -> Result, Box> { tracing::debug!("Received GetAssertion: {request:?}"); let client = self.get_client(); + + let plugin_hwnd = get_window_handle(&client)?; + unsafe { + let dw_current_thread = GetCurrentThreadId(); + let dw_fg_thread = GetWindowThreadProcessId(GetForegroundWindow(), None); + let result = AttachThreadInput(dw_current_thread, dw_fg_thread, true); + tracing::debug!("AttachThreadInput() - attach? {result:?}"); + let result = BringWindowToTop(plugin_hwnd); + tracing::debug!("BringWindowToTop? {result:?}"); + let result = AttachThreadInput(dw_current_thread, dw_fg_thread, false); + tracing::debug!("AttachThreadInput() - detach? {result:?}"); + }; let (cancel_tx, cancel_rx) = mpsc::channel(); let transaction_id = request.transaction_id; self.callbacks @@ -200,3 +268,27 @@ impl PluginAuthenticator for BitwardenPluginAuthenticator { } } } + +fn get_window_handle(client: &WindowsProviderClient) -> Result { + tracing::debug!("Get Window Handle!"); + let window_handle_callback = Arc::new(TimedCallback::new()); + client.get_window_handle(window_handle_callback.clone()); + let plugin_window_handle = window_handle_callback + .wait_for_response(Duration::from_secs(3), None) + .unwrap() + .unwrap() + .handle; + unsafe { + // SAFETY: We check to make sure that the vec is the expected size + // before converting it. If the handle is invalid when passed to + // Windows, the request will be rejected. + if plugin_window_handle.len() == size_of::() { + Ok(*plugin_window_handle.as_ptr().cast()) + } else { + Err(format!( + "Invalid window handle received: {:?}", + plugin_window_handle + )) + } + } +} diff --git a/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs b/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs index 3e7f5df5f7c..297cdcf496c 100644 --- a/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs +++ b/apps/desktop/desktop_native/windows_plugin_authenticator/src/util.rs @@ -1,4 +1,5 @@ use base64::engine::{general_purpose::STANDARD, Engine as _}; +use serde::{de::Visitor, Deserializer}; use windows::{ core::GUID, Win32::{ @@ -44,3 +45,24 @@ impl HwndExt for HWND { pub fn create_context_string(transaction_id: GUID) -> String { STANDARD.encode(transaction_id.to_u128().to_le_bytes().to_vec()) } + +pub fn deserialize_b64<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + deserializer.deserialize_str(Base64Visitor {}) +} + +struct Base64Visitor; +impl<'de> Visitor<'de> for Base64Visitor { + type Value = Vec; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("A valid base64 string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + use base64::{engine::general_purpose::STANDARD, Engine as _}; + STANDARD.decode(v).map_err(|err| E::custom(err)) + } +} diff --git a/apps/desktop/src/autofill/preload.ts b/apps/desktop/src/autofill/preload.ts index 361b1d21777..8fc62d229da 100644 --- a/apps/desktop/src/autofill/preload.ts +++ b/apps/desktop/src/autofill/preload.ts @@ -225,4 +225,42 @@ export default { }, ); }, + listenGetWindowHandle: ( + fn: ( + clientId: number, + sequenceNumber: number, + request: autofill.WindowHandleQueryRequest, + completeCallback: (error: Error | null, response: autofill.WindowHandleQueryResponse) => void, + ) => void, + ) => { + ipcRenderer.on( + "autofill.windowHandleQuery", + ( + event, + data: { + clientId: number; + sequenceNumber: number; + request: autofill.WindowHandleQueryRequest; + }, + ) => { + const { clientId, sequenceNumber, request } = data; + fn(clientId, sequenceNumber, request, (error, response) => { + if (error) { + ipcRenderer.send("autofill.completeError", { + clientId, + sequenceNumber, + error: error.message, + }); + return; + } + + ipcRenderer.send("autofill.completeWindowHandleQuery", { + clientId, + sequenceNumber, + response, + }); + }); + }, + ); + }, }; diff --git a/apps/desktop/src/autofill/services/desktop-autofill.service.ts b/apps/desktop/src/autofill/services/desktop-autofill.service.ts index d6a2cf0d5a5..1cc4aa4012a 100644 --- a/apps/desktop/src/autofill/services/desktop-autofill.service.ts +++ b/apps/desktop/src/autofill/services/desktop-autofill.service.ts @@ -400,6 +400,21 @@ export class DesktopAutofillService implements OnDestroy { callback(null, { isUnlocked }) }) + ipc.autofill.listenGetWindowHandle(async (clientId, sequenceNumber, request, callback) => { + if (!(await this.configService.getFeatureFlag(NativeCredentialSyncFeatureFlag))) { + this.logService.debug( + `listenGetWindowHandle: ${NativeCredentialSyncFeatureFlag} feature flag is disabled`, + ); + return; + } + + this.logService.debug("listenGetWindowHandle", clientId, sequenceNumber, request); + let handle = Utils.fromBufferToB64(await ipc.platform.getNativeWindowHandle()); + const response = { handle }; + this.logService.debug("listenGetWindowHandle: sending", response); + callback(null, { handle }) + }) + ipc.autofill.listenerReady(); } diff --git a/apps/desktop/src/platform/main/autofill/native-autofill.main.ts b/apps/desktop/src/platform/main/autofill/native-autofill.main.ts index 74aa188b86f..e82ee9a4bfb 100644 --- a/apps/desktop/src/platform/main/autofill/native-autofill.main.ts +++ b/apps/desktop/src/platform/main/autofill/native-autofill.main.ts @@ -151,6 +151,19 @@ export class NativeAutofillMain { request, }); }, + // WindowHandleQueryCallback + (error, clientId, sequenceNumber, request) => { + if (error) { + this.logService.error("autofill.IpcServer.windowHandleQuery", error); + this.ipcServer.completeError(clientId, sequenceNumber, String(error)); + return; + } + this.safeSend("autofill.windowHandleQuery", { + clientId, + sequenceNumber, + request, + }); + }, ); ipcMain.on("autofill.listenerReady", () => { @@ -179,6 +192,11 @@ export class NativeAutofillMain { this.ipcServer.completeLockStatusQuery(clientId, sequenceNumber, response); }); + ipcMain.on("autofill.completeWindowHandleQuery", (event, data) => { + this.logService.debug("autofill.completeWindowHandleQuery", data); + const { clientId, sequenceNumber, response } = data; + this.ipcServer.completeWindowHandleQuery(clientId, sequenceNumber, response); + }); ipcMain.on("autofill.completeError", (event, data) => {