diff --git a/apps/desktop/desktop_native/macos_provider/src/lib.rs b/apps/desktop/desktop_native/macos_provider/src/lib.rs index 8619a77a0f2..0ef734449ed 100644 --- a/apps/desktop/desktop_native/macos_provider/src/lib.rs +++ b/apps/desktop/desktop_native/macos_provider/src/lib.rs @@ -9,6 +9,7 @@ use std::{ use futures::FutureExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::sync::mpsc::Sender; use tracing::{error, info}; use tracing_subscriber::{ filter::{EnvFilter, LevelFilter}, @@ -114,6 +115,8 @@ impl MacOSProviderClient { let (from_server_send, mut from_server_recv) = tokio::sync::mpsc::channel(32); let (to_server_send, to_server_recv) = tokio::sync::mpsc::channel(32); + let (host_request_handler_tx, mut host_request_handler_rx) = tokio::sync::mpsc::channel(32); + let to_server_send2 = to_server_send.clone(); let client = MacOSProviderClient { to_server_send, @@ -139,8 +142,15 @@ impl MacOSProviderClient { .map(|r| r.map_err(|e| e.to_string())), ); + rt.spawn(async move { + while let Some(message) = host_request_handler_rx.recv().await { + handle_host_request(&to_server_send2, message).await; + } + }); + rt.block_on(async move { while let Some(message) = from_server_recv.recv().await { + tracing::debug!(?message, "Received message"); match serde_json::from_str::(&message) { Ok(SerializedMessage::Command(CommandMessage::Connected)) => { info!("Connected to server"); @@ -150,6 +160,15 @@ impl MacOSProviderClient { info!("Disconnected from server"); connection_status.store(false, std::sync::atomic::Ordering::Relaxed); } + Ok(SerializedMessage::HostRequest(message)) => { + let request_id = message.request_id; + tracing::debug!(%request_id, "Received request"); + if let Err(err) = host_request_handler_tx.send(message).await { + tracing::error!( + "Failed to pass message to host request handler: {err}" + ); + } + } Ok(SerializedMessage::Message { sequence_number, value, @@ -227,17 +246,50 @@ impl MacOSProviderClient { } } -#[derive(Serialize, Deserialize)] #[serde(tag = "command", rename_all = "camelCase")] enum CommandMessage { Connected, Disconnected, } +/// Requests from the host to the provider. +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct HostRequestMessage { + request_id: u32, + #[serde(flatten)] + request: HostRequest, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "request", content = "params", rename_all = "camelCase")] +enum HostRequest { + #[serde(rename_all = "camelCase")] + UserVerification { + transaction_id: u32, + display_hint: String, + username: Option, + }, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct HostResponseMessage { + request_id: u32, + response: HostResponse, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "request", content = "value", rename_all = "camelCase")] +enum HostResponse { + UserVerification { user_verified: bool }, +} + #[derive(Serialize, Deserialize)] #[serde(untagged, rename_all = "camelCase")] enum SerializedMessage { Command(CommandMessage), + HostRequest(HostRequestMessage), Message { sequence_number: u32, value: Result, @@ -294,3 +346,57 @@ impl MacOSProviderClient { } } } + +/// Handles requests from the host to the provider. +async fn handle_host_request(to_server_send: &Sender, message: HostRequestMessage) { + let request_id = message.request_id; + let response = match message.request { + uv_request @ HostRequest::UserVerification { .. } => { + tracing::debug!("Received UV request: {uv_request:?}"); + HostResponse::UserVerification { + user_verified: true, + } + } + }; + let message = serde_json::to_string(&HostResponseMessage { + request_id, + response, + }) + .expect("Can't serialize message"); + + if let Err(e) = to_server_send.send(message).await { + error!(%request_id, "Could not send response back to host: {e}"); + } +} + +#[cfg(test)] +mod tests { + use crate::HostRequest; + + use super::{HostRequestMessage, SerializedMessage}; + + #[test] + fn test_deserialize_host_request() { + let json = r#"{ + "requestId": 1, + "request": "userVerification", + "params": { + "transactionId": 0, + "displayHint": "Verify it's you to overwrite a credential", + "username": "bw-ii-plugin1" + } + }"#; + let value = serde_json::from_str::(&json).unwrap(); + let message = serde_json::from_value::(value).unwrap(); + assert!(matches!( + message, + SerializedMessage::HostRequest(HostRequestMessage { + request_id: 1, + request: HostRequest::UserVerification { + transaction_id: 0, + .. + }, + }), + )); + } +} diff --git a/apps/desktop/desktop_native/napi/src/lib.rs b/apps/desktop/desktop_native/napi/src/lib.rs index e2ef395424f..0ca1b7edb36 100644 --- a/apps/desktop/desktop_native/napi/src/lib.rs +++ b/apps/desktop/desktop_native/napi/src/lib.rs @@ -641,8 +641,13 @@ pub mod autofill { threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; + use tokio::sync::oneshot; use tracing::error; + /// In our callback management, 0 is a reserved sequence number indicating that a message does not + /// have a callback. + const NO_CALLBACK_INDICATOR: u32 = 0; + #[napi] pub async fn run_command(value: String) -> napi::Result { desktop_core::autofill::run_command(value) @@ -690,6 +695,20 @@ pub mod autofill { pub value: Result, } + #[derive(Serialize, Deserialize)] + #[serde(tag = "request", content = "params", rename_all = "camelCase")] + pub enum HostRequest { + UserVerification(UserVerificationRequest), + } + + #[derive(Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + struct HostRequestMessage { + request_id: u32, + #[serde(flatten)] + request: HostRequest, + } + #[napi(object)] #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -777,6 +796,9 @@ pub mod autofill { #[napi] pub struct AutofillIpcServer { server: desktop_core::ipc::server::Server, + // We need to keep track of the callbacks so we can call them when we receive a response + host_callbacks_counter: AtomicU32, + host_callbacks: Arc>>>>, } // FIXME: Remove unwraps! They panic and terminate the whole application. @@ -941,7 +963,12 @@ pub mod autofill { )) })?; - Ok(AutofillIpcServer { server }) + Ok(AutofillIpcServer { + server, + // Start at 1 since 0 is reserved for "no callback" scenarios + host_callbacks_counter: AtomicU32::new(NO_CALLBACK_INDICATOR + 1), + host_callbacks: Arc::new(Mutex::new(HashMap::new())), + }) } /// Return the path to the IPC server.