diff --git a/apps/desktop/desktop_native/macos_provider/src/assertion.rs b/apps/desktop/desktop_native/macos_provider/src/assertion.rs index c5b43bb87fa..ea8954d2f3b 100644 --- a/apps/desktop/desktop_native/macos_provider/src/assertion.rs +++ b/apps/desktop/desktop_native/macos_provider/src/assertion.rs @@ -7,12 +7,12 @@ use crate::{BitwardenError, Callback, Position, UserVerification}; #[derive(uniffi::Record, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PasskeyAssertionRequest { - rp_id: String, - client_data_hash: Vec, - user_verification: UserVerification, - allowed_credentials: Vec>, - window_xy: Position, - //extension_input: Vec, TODO: Implement support for extensions + pub(crate) rp_id: String, + pub(crate) client_data_hash: Vec, + pub(crate) user_verification: UserVerification, + pub(crate) allowed_credentials: Vec>, + pub(crate) window_xy: Position, + // pub(crate) extension_input: Vec, TODO: Implement support for extensions } #[derive(uniffi::Record, Debug, Serialize, Deserialize)] diff --git a/apps/desktop/desktop_native/macos_provider/src/lib.rs b/apps/desktop/desktop_native/macos_provider/src/lib.rs index 0ef734449ed..a1199cf959f 100644 --- a/apps/desktop/desktop_native/macos_provider/src/lib.rs +++ b/apps/desktop/desktop_native/macos_provider/src/lib.rs @@ -8,7 +8,7 @@ use std::{ }; use futures::FutureExt; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Sender; use tracing::{error, info}; use tracing_subscriber::{ @@ -21,6 +21,7 @@ uniffi::setup_scaffolding!(); mod assertion; mod registration; +mod user_verification; use assertion::{ PasskeyAssertionRequest, PasskeyAssertionWithoutUserInterfaceRequest, @@ -28,6 +29,8 @@ use assertion::{ }; use registration::{PasskeyRegistrationRequest, PreparePasskeyRegistrationCallback}; +use crate::user_verification::{UserVerificationRequest, UserVerificationResponse}; + static INIT: Once = Once::new(); #[derive(uniffi::Enum, Debug, Serialize, Deserialize)] @@ -161,8 +164,8 @@ impl MacOSProviderClient { connection_status.store(false, std::sync::atomic::Ordering::Relaxed); } Ok(SerializedMessage::HostRequest(message)) => { - let request_id = message.request_id; - tracing::debug!(%request_id, "Received request"); + let sequence_number = message.sequence_number; + tracing::debug!(%sequence_number, "Received request"); if let Err(err) = host_request_handler_tx.send(message).await { tracing::error!( "Failed to pass message to host request handler: {err}" @@ -207,7 +210,7 @@ impl MacOSProviderClient { pub fn send_native_status(&self, key: String, value: String) { let status = NativeStatus { key, value }; - self.send_message(status, None); + self.send_message(ExtensionRequest::NativeStatus(status), None); } pub fn prepare_passkey_registration( @@ -215,7 +218,10 @@ impl MacOSProviderClient { request: PasskeyRegistrationRequest, callback: Arc, ) { - self.send_message(request, Some(Box::new(callback))); + self.send_message( + ExtensionRequest::PasskeyRegistration(request), + Some(Box::new(callback)), + ); } pub fn prepare_passkey_assertion( @@ -223,7 +229,10 @@ impl MacOSProviderClient { request: PasskeyAssertionRequest, callback: Arc, ) { - self.send_message(request, Some(Box::new(callback))); + self.send_message( + ExtensionRequest::PasskeyAssertion(request), + Some(Box::new(callback)), + ); } pub fn prepare_passkey_assertion_without_user_interface( @@ -231,7 +240,10 @@ impl MacOSProviderClient { request: PasskeyAssertionWithoutUserInterfaceRequest, callback: Arc, ) { - self.send_message(request, Some(Box::new(callback))); + self.send_message( + ExtensionRequest::PasskeyAssertionWithoutUserInterface(request), + Some(Box::new(callback)), + ); } pub fn get_connection_status(&self) -> ConnectionStatus { @@ -246,17 +258,37 @@ impl MacOSProviderClient { } } -#[serde(tag = "command", rename_all = "camelCase")] +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "command", content = "params", rename_all = "camelCase")] enum CommandMessage { Connected, Disconnected, } +/// Requests from the extension to the host. +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct ExtensionRequestMessage { + sequence_number: u32, + #[serde(flatten)] + value: ExtensionRequest, +} + +/// Requests from the extension to the host. +#[derive(Serialize)] +#[serde(tag = "request", content = "params", rename_all = "camelCase")] +enum ExtensionRequest { + NativeStatus(NativeStatus), + PasskeyAssertion(PasskeyAssertionRequest), + PasskeyAssertionWithoutUserInterface(PasskeyAssertionWithoutUserInterfaceRequest), + PasskeyRegistration(PasskeyRegistrationRequest), +} + /// Requests from the host to the provider. #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct HostRequestMessage { - request_id: u32, + sequence_number: u32, #[serde(flatten)] request: HostRequest, } @@ -264,25 +296,21 @@ struct HostRequestMessage { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "request", content = "params", rename_all = "camelCase")] enum HostRequest { - #[serde(rename_all = "camelCase")] - UserVerification { - transaction_id: u32, - display_hint: String, - username: Option, - }, + UserVerification(UserVerificationRequest), } #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct HostResponseMessage { - request_id: u32, - response: HostResponse, + sequence_number: u32, + #[serde(flatten)] + response: Result, } #[derive(Serialize, Deserialize)] -#[serde(tag = "request", content = "value", rename_all = "camelCase")] +#[serde(tag = "response", content = "value", rename_all = "camelCase")] enum HostResponse { - UserVerification { user_verified: bool }, + UserVerification(UserVerificationResponse), } #[derive(Serialize, Deserialize)] @@ -290,6 +318,7 @@ enum HostResponse { enum SerializedMessage { Command(CommandMessage), HostRequest(HostRequestMessage), + #[serde(rename_all = "camelCase")] Message { sequence_number: u32, value: Result, @@ -312,22 +341,19 @@ impl MacOSProviderClient { } #[allow(clippy::unwrap_used)] - fn send_message( - &self, - message: impl Serialize + DeserializeOwned, - callback: Option>, - ) { + fn send_message(&self, message: ExtensionRequest, callback: Option>) { let sequence_number = if let Some(callback) = callback { self.add_callback(callback) } else { NO_CALLBACK_INDICATOR }; - let message = serde_json::to_string(&SerializedMessage::Message { + let message = serde_json::to_string(&ExtensionRequestMessage { sequence_number, - value: Ok(serde_json::to_value(message).unwrap()), + value: message, }) .expect("Can't serialize message"); + tracing::debug!(%message, "Sending message to host"); if let Err(e) = self.to_server_send.blocking_send(message) { // Make sure we remove the callback from the queue if we can't send the message @@ -349,36 +375,44 @@ impl MacOSProviderClient { /// Handles requests from the host to the provider. async fn handle_host_request(to_server_send: &Sender, message: HostRequestMessage) { - let request_id = message.request_id; + let sequence_number = message.sequence_number; let response = match message.request { uv_request @ HostRequest::UserVerification { .. } => { tracing::debug!("Received UV request: {uv_request:?}"); - HostResponse::UserVerification { + Ok(HostResponse::UserVerification(UserVerificationResponse { user_verified: true, - } + })) } }; let message = serde_json::to_string(&HostResponseMessage { - request_id, + sequence_number, response, }) .expect("Can't serialize message"); if let Err(e) = to_server_send.send(message).await { - error!(%request_id, "Could not send response back to host: {e}"); + error!(%sequence_number, "Could not send response back to host: {e}"); } } #[cfg(test)] mod tests { - use crate::HostRequest; + use serde_json::Value; + + use crate::{ + assertion::PasskeyAssertionRequest, + registration::PasskeyRegistrationRequest, + user_verification::{UserVerificationRequest, UserVerificationResponse}, + ExtensionRequest, ExtensionRequestMessage, HostRequest, HostResponse, HostResponseMessage, + Position, + }; use super::{HostRequestMessage, SerializedMessage}; #[test] fn test_deserialize_host_request() { let json = r#"{ - "requestId": 1, + "sequenceNumber": 1, "request": "userVerification", "params": { "transactionId": 0, @@ -391,12 +425,74 @@ mod tests { assert!(matches!( message, SerializedMessage::HostRequest(HostRequestMessage { - request_id: 1, - request: HostRequest::UserVerification { + sequence_number: 1, + request: HostRequest::UserVerification(UserVerificationRequest { transaction_id: 0, .. - }, + }), }), )); } + + #[test] + fn test_serialize_host_response() { + let message = HostResponseMessage { + sequence_number: 7, + response: Ok(HostResponse::UserVerification(UserVerificationResponse { + user_verified: true, + })), + }; + let json = serde_json::to_string(&message).unwrap(); + let value: Value = serde_json::from_str(&json).unwrap(); + assert_eq!(value["sequenceNumber"], 7); + assert_eq!(value["Ok"]["response"], "userVerification"); + assert_eq!(value["Ok"]["value"]["userVerified"], true); + } + + #[test] + fn test_serialize_extension_request() { + let message = ExtensionRequestMessage { + sequence_number: 42, + value: ExtensionRequest::PasskeyAssertion(PasskeyAssertionRequest { + rp_id: "example.com".to_string(), + client_data_hash: vec![1; 32], + user_verification: crate::UserVerification::Preferred, + allowed_credentials: vec![vec![4; 8]], + window_xy: Position { x: 100, y: 200 }, + }), + }; + let json = serde_json::to_string(&message).unwrap(); + let value: Value = serde_json::from_str(&json).unwrap(); + assert_eq!(value["sequenceNumber"], 42); + assert_eq!(value["request"], "passkeyAssertion"); + let request: PasskeyAssertionRequest = + serde_json::from_value(value.as_object().unwrap().get("params").unwrap().clone()) + .unwrap(); + assert_eq!(request.rp_id, "example.com"); + } + + #[test] + fn test_deserialize_extension_response() { + let json = r#"{ + "sequenceNumber": 1, + "value": { + "Ok": { + "rpId": "webauthn.io", + "clientDataHash": [156, 40, 76, 228, 28, 215, 79, 194, 237, 160, 250, 176, 57, 185, 247, 83, 247, 175, 218, 126, 161, 115, 202, 31, 71, 77, 49, 113, 197, 203, 88, 90], + "credentialId": [68, 161, 162, 129, 42, 70, 71, 239, 163, 98, 224, 14, 37, 190, 19, 70], + "attestationObject": [163, 99, 102, 109, 116, 100, 110, 111, 110, 101, 103, 97, 116, 116, 83, 116, 109, 116, 160, 104, 97, 117, 116, 104, 68, 97, 116, 97, 88, 148, 116, 166, 234, 146, 19, 201, 156, 47, 116, 178, 36, 146, 179, 32, 207, 64, 38, 42, 148, 193, 169, 80, 160, 57, 127, 41, 37, 11, 96, 132, 30, 240, 93, 0, 0, 0, 0, 213, 72, 130, 110, 121, 180, 219, 64, 163, 216, 17, 17, 111, 126, 131, 73, 0, 16, 68, 161, 162, 129, 42, 70, 71, 239, 163, 98, 224, 14, 37, 190, 19, 70, 165, 1, 2, 3, 38, 32, 1, 33, 88, 32, 204, 69, 19, 156, 78, 44, 190, 84, 242, 39, 36, 208, 150, 253, 237, 217, 249, 181, 225, 233, 218, 51, 252, 30, 63, 228, 232, 116, 70, 69, 107, 137, 34, 88, 32, 102, 120, 26, 113, 188, 129, 247, 29, 166, 195, 112, 151, 177, 248, 83, 120, 132, 188, 128, 160, 113, 89, 2, 141, 8, 190, 110, 6, 220, 5, 181, 96] + } + } + }"#; + let message = serde_json::from_str::(&json).unwrap(); + if let SerializedMessage::Message { + sequence_number: 1, + value: Ok(value), + } = message + { + assert_eq!(value["rpId"], "webauthn.io"); + } else { + panic!("Does not match"); + } + } } diff --git a/apps/desktop/desktop_native/macos_provider/src/user_verification.rs b/apps/desktop/desktop_native/macos_provider/src/user_verification.rs new file mode 100644 index 00000000000..596d8b31e34 --- /dev/null +++ b/apps/desktop/desktop_native/macos_provider/src/user_verification.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct UserVerificationRequest { + pub(crate) transaction_id: u32, + pub(crate) display_hint: String, + pub(crate) username: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct UserVerificationResponse { + pub(crate) user_verified: bool, +} diff --git a/apps/desktop/desktop_native/napi/index.d.ts b/apps/desktop/desktop_native/napi/index.d.ts index 75510b53a78..1956bdc020a 100644 --- a/apps/desktop/desktop_native/napi/index.d.ts +++ b/apps/desktop/desktop_native/napi/index.d.ts @@ -21,8 +21,6 @@ export declare namespace autofill { /** Prompt a user for user verification using OS APIs. */ verifyUser(request: UserVerificationRequest): Promise } - export type HostResponse = - | { type: 'UserVerification', field0: UserVerificationResponse } export interface NativeStatus { key: string value: string diff --git a/apps/desktop/desktop_native/napi/src/lib.rs b/apps/desktop/desktop_native/napi/src/lib.rs index 7d93e0a466a..0ce59eba24d 100644 --- a/apps/desktop/desktop_native/napi/src/lib.rs +++ b/apps/desktop/desktop_native/napi/src/lib.rs @@ -645,7 +645,7 @@ pub mod autofill { bindgen_prelude::FnArgs, threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}, }; - use serde::{de::DeserializeOwned, Deserialize, Serialize}; + use serde::{Deserialize, Serialize}; use tokio::sync::oneshot; use tracing::error; @@ -692,12 +692,29 @@ pub mod autofill { pub is_focused: bool, pub handle: String, } + #[derive(Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct ExtensionRequestMessage { + pub sequence_number: u32, + #[serde(flatten)] + pub request: ExtensionRequest, + } #[derive(Serialize, Deserialize)] - #[serde(bound = "T: Serialize + DeserializeOwned")] - pub struct PasskeyMessage { - pub sequence_number: u32, - pub value: Result, + #[serde(tag = "request", content = "params", rename_all = "camelCase")] + pub enum ExtensionRequest { + NativeStatus(NativeStatus), + PasskeyAssertion(PasskeyAssertionRequest), + PasskeyAssertionWithoutUserInterface(PasskeyAssertionWithoutUserInterfaceRequest), + PasskeyRegistration(PasskeyRegistrationRequest), + WindowHandle(WindowHandleQueryRequest), + } + + #[derive(Serialize)] + #[serde(bound = "T: Serialize", rename_all = "camelCase")] + struct ExtensionResponse { + sequence_number: u32, + value: Result, } #[derive(Serialize, Deserialize)] @@ -709,11 +726,13 @@ pub mod autofill { #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct HostRequestMessage { - request_id: u32, + sequence_number: u32, #[serde(flatten)] request: HostRequest, } + // TODO: HOST RESPONSE MESSAGE + #[napi(object)] #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -814,17 +833,35 @@ pub mod autofill { pub user_verified: bool, } - #[napi] - pub enum HostResponse { + #[derive(Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + struct HostResponseMessage { + sequence_number: u32, + #[serde(flatten)] + response: Result, + } + + #[derive(Serialize, Deserialize)] + #[serde(tag = "response", content = "value", rename_all = "camelCase")] + enum HostResponse { UserVerification(UserVerificationResponse), } + #[derive(Deserialize)] + #[serde(untagged)] + enum IpcMessage { + HostResponse(HostResponseMessage), + ExtensionRequest(ExtensionRequestMessage), + } + + type HostCallbacks = Arc>>>>; + #[napi] pub struct AutofillIpcServer { server: desktop_core::ipc::server::Server, // We need to keep track of the callbacks so we can call them when we receive a response host_callbacks_counter: AtomicU32, - host_callbacks: Arc>>>>, + host_callbacks: HostCallbacks, } // FIXME: Remove unwraps! They panic and terminate the whole application. @@ -872,6 +909,8 @@ pub mod autofill { >, ) -> napi::Result { let (send, mut recv) = tokio::sync::mpsc::channel::(32); + let host_callbacks: HostCallbacks = Arc::new(Mutex::new(HashMap::new())); + let host_callbacks2 = host_callbacks.clone(); tokio::spawn(async move { while let Some(Message { client_id, @@ -888,94 +927,75 @@ pub mod autofill { continue; }; - match serde_json::from_str::>( - &message, - ) { - Ok(msg) => { - let value = msg - .value - .map(|value| (client_id, msg.sequence_number, value).into()) - .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); - - window_handle_query_callback - .call(value, ThreadsafeFunctionCallMode::NonBlocking); - continue; - } - Err(e) => { - tracing::warn!(error = %e, "Could not deserialize request as WindowHandleQueryRequest. Trying other types..."); - } - } - - match serde_json::from_str::>( - &message, - ) { - Ok(msg) => { - let value = msg - .value - .map(|value| (client_id, msg.sequence_number, value).into()) - .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); - - assertion_callback - .call(value, ThreadsafeFunctionCallMode::NonBlocking); - continue; - } - Err(e) => { - error!(error = %e, "Error deserializing as PasskeyAssertionRequest"); - } - } - - match serde_json::from_str::< - PasskeyMessage, - >(&message) - { - Ok(msg) => { - let value = msg - .value - .map(|value| (client_id, msg.sequence_number, value).into()) - .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); - - assertion_without_user_interface_callback - .call(value, ThreadsafeFunctionCallMode::NonBlocking); - continue; - } - Err(e) => { - error!(error = %e, "Error deserializing as PasskeyAssertionWithoutUserInterfaceRequest"); - } - } - - match serde_json::from_str::>( - &message, - ) { - Ok(msg) => { - let value = msg - .value - .map(|value| (client_id, msg.sequence_number, value).into()) - .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); - registration_callback - .call(value, ThreadsafeFunctionCallMode::NonBlocking); - continue; - } - Err(e) => { - error!(error = %e, "Error deserializing PasskeyRegistrationRequest"); - } - } - - match serde_json::from_str::>(&message) { - Ok(msg) => { - let value = msg - .value - .map(|value| (client_id, msg.sequence_number, value)) - .map_err(|e| napi::Error::from_reason(format!("{e:?}"))); - native_status_callback - .call(value, ThreadsafeFunctionCallMode::NonBlocking); - continue; - } + let msg = match serde_json::from_str::(&message) { + Ok(msg) => msg, Err(error) => { - error!(%error, "Unable to deserialize as native status."); + error!( + %error, + %message, + "Received an unknown message from extension" + ); + continue; } - } + }; - error!(message, "Received an unknown message"); + match msg { + IpcMessage::HostResponse(msg) => { + if let Some(tx) = host_callbacks2 + .lock() + .expect("not poisoned") + .remove(&msg.sequence_number) + { + if let Err(err) = tx.send(msg.response) { + tracing::error!(""); + } + } + } + IpcMessage::ExtensionRequest(msg) => match msg.request { + ExtensionRequest::PasskeyAssertion(assertion_request) => { + let params = + (client_id, msg.sequence_number, assertion_request); + assertion_callback.call( + Ok(params.into()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + ExtensionRequest::PasskeyAssertionWithoutUserInterface( + assertion_request, + ) => { + let params = + (client_id, msg.sequence_number, assertion_request); + assertion_without_user_interface_callback.call( + Ok(params.into()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + ExtensionRequest::PasskeyRegistration(assertion_request) => { + let params = + (client_id, msg.sequence_number, assertion_request); + registration_callback.call( + Ok(params.into()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + ExtensionRequest::WindowHandle(window_handle_request) => { + let params = + (client_id, msg.sequence_number, window_handle_request); + window_handle_query_callback.call( + Ok(params.into()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + ExtensionRequest::NativeStatus(status_request) => { + let params = + (client_id, msg.sequence_number, status_request); + native_status_callback.call( + Ok(params.into()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + }, + } } } } @@ -993,7 +1013,7 @@ pub mod autofill { server, // Start at 1 since 0 is reserved for "no callback" scenarios host_callbacks_counter: AtomicU32::new(NO_CALLBACK_INDICATOR + 1), - host_callbacks: Arc::new(Mutex::new(HashMap::new())), + host_callbacks, }) } @@ -1017,7 +1037,7 @@ pub mod autofill { sequence_number: u32, response: PasskeyRegistrationResponse, ) -> napi::Result { - let message = PasskeyMessage { + let message = ExtensionResponse { sequence_number, value: Ok(response), }; @@ -1031,7 +1051,7 @@ pub mod autofill { sequence_number: u32, response: PasskeyAssertionResponse, ) -> napi::Result { - let message = PasskeyMessage { + let message = ExtensionResponse { sequence_number, value: Ok(response), }; @@ -1045,7 +1065,7 @@ pub mod autofill { sequence_number: u32, response: WindowHandleQueryResponse, ) -> napi::Result { - let message = PasskeyMessage { + let message = ExtensionResponse { sequence_number, value: Ok(response), }; @@ -1059,7 +1079,7 @@ pub mod autofill { sequence_number: u32, error: String, ) -> napi::Result { - let message: PasskeyMessage<()> = PasskeyMessage { + let message: ExtensionResponse<()> = ExtensionResponse { sequence_number, value: Err(BitwardenError::Internal(error)), }; @@ -1072,18 +1092,18 @@ pub mod autofill { &self, request: UserVerificationRequest, ) -> napi::Result { - let request_id = self + let sequence_number = self .host_callbacks_counter .fetch_add(1, std::sync::atomic::Ordering::SeqCst); let (tx, rx) = oneshot::channel(); let command = HostRequestMessage { - request_id, + sequence_number, request: HostRequest::UserVerification(request), }; self.host_callbacks .lock() .expect("not poisoned") - .insert(request_id, tx); + .insert(sequence_number, tx); let json = serde_json::to_string(&command).expect("serde to serialize"); tracing::debug!(json, "Sending verify user message"); self.send(0, json)?; @@ -1104,6 +1124,7 @@ pub mod autofill { // TODO: Add a way to send a message to a specific client? fn send(&self, _client_id: u32, message: String) -> napi::Result { + tracing::debug!(%message, "Sending message to extension"); self.server .send(message) .map_err(|e| {