1
0
mirror of https://github.com/bitwarden/browser synced 2026-01-31 00:33:33 +00:00

Stricter types for Extension<->Host IPC

This commit is contained in:
Isaiah Inuwa
2026-01-09 09:04:27 -06:00
parent 969900e22f
commit 1d423e80e2
5 changed files with 276 additions and 146 deletions

View File

@@ -7,12 +7,12 @@ use crate::{BitwardenError, Callback, Position, UserVerification};
#[derive(uniffi::Record, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PasskeyAssertionRequest {
rp_id: String,
client_data_hash: Vec<u8>,
user_verification: UserVerification,
allowed_credentials: Vec<Vec<u8>>,
window_xy: Position,
//extension_input: Vec<u8>, TODO: Implement support for extensions
pub(crate) rp_id: String,
pub(crate) client_data_hash: Vec<u8>,
pub(crate) user_verification: UserVerification,
pub(crate) allowed_credentials: Vec<Vec<u8>>,
pub(crate) window_xy: Position,
// pub(crate) extension_input: Vec<u8>, TODO: Implement support for extensions
}
#[derive(uniffi::Record, Debug, Serialize, Deserialize)]

View File

@@ -8,7 +8,7 @@ use std::{
};
use futures::FutureExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Sender;
use tracing::{error, info};
use tracing_subscriber::{
@@ -21,6 +21,7 @@ uniffi::setup_scaffolding!();
mod assertion;
mod registration;
mod user_verification;
use assertion::{
PasskeyAssertionRequest, PasskeyAssertionWithoutUserInterfaceRequest,
@@ -28,6 +29,8 @@ use assertion::{
};
use registration::{PasskeyRegistrationRequest, PreparePasskeyRegistrationCallback};
use crate::user_verification::{UserVerificationRequest, UserVerificationResponse};
static INIT: Once = Once::new();
#[derive(uniffi::Enum, Debug, Serialize, Deserialize)]
@@ -161,8 +164,8 @@ impl MacOSProviderClient {
connection_status.store(false, std::sync::atomic::Ordering::Relaxed);
}
Ok(SerializedMessage::HostRequest(message)) => {
let request_id = message.request_id;
tracing::debug!(%request_id, "Received request");
let sequence_number = message.sequence_number;
tracing::debug!(%sequence_number, "Received request");
if let Err(err) = host_request_handler_tx.send(message).await {
tracing::error!(
"Failed to pass message to host request handler: {err}"
@@ -207,7 +210,7 @@ impl MacOSProviderClient {
pub fn send_native_status(&self, key: String, value: String) {
let status = NativeStatus { key, value };
self.send_message(status, None);
self.send_message(ExtensionRequest::NativeStatus(status), None);
}
pub fn prepare_passkey_registration(
@@ -215,7 +218,10 @@ impl MacOSProviderClient {
request: PasskeyRegistrationRequest,
callback: Arc<dyn PreparePasskeyRegistrationCallback>,
) {
self.send_message(request, Some(Box::new(callback)));
self.send_message(
ExtensionRequest::PasskeyRegistration(request),
Some(Box::new(callback)),
);
}
pub fn prepare_passkey_assertion(
@@ -223,7 +229,10 @@ impl MacOSProviderClient {
request: PasskeyAssertionRequest,
callback: Arc<dyn PreparePasskeyAssertionCallback>,
) {
self.send_message(request, Some(Box::new(callback)));
self.send_message(
ExtensionRequest::PasskeyAssertion(request),
Some(Box::new(callback)),
);
}
pub fn prepare_passkey_assertion_without_user_interface(
@@ -231,7 +240,10 @@ impl MacOSProviderClient {
request: PasskeyAssertionWithoutUserInterfaceRequest,
callback: Arc<dyn PreparePasskeyAssertionCallback>,
) {
self.send_message(request, Some(Box::new(callback)));
self.send_message(
ExtensionRequest::PasskeyAssertionWithoutUserInterface(request),
Some(Box::new(callback)),
);
}
pub fn get_connection_status(&self) -> ConnectionStatus {
@@ -246,17 +258,37 @@ impl MacOSProviderClient {
}
}
#[serde(tag = "command", rename_all = "camelCase")]
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "command", content = "params", rename_all = "camelCase")]
enum CommandMessage {
Connected,
Disconnected,
}
/// Requests from the extension to the host.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExtensionRequestMessage {
sequence_number: u32,
#[serde(flatten)]
value: ExtensionRequest,
}
/// Requests from the extension to the host.
#[derive(Serialize)]
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
enum ExtensionRequest {
NativeStatus(NativeStatus),
PasskeyAssertion(PasskeyAssertionRequest),
PasskeyAssertionWithoutUserInterface(PasskeyAssertionWithoutUserInterfaceRequest),
PasskeyRegistration(PasskeyRegistrationRequest),
}
/// Requests from the host to the provider.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostRequestMessage {
request_id: u32,
sequence_number: u32,
#[serde(flatten)]
request: HostRequest,
}
@@ -264,25 +296,21 @@ struct HostRequestMessage {
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
enum HostRequest {
#[serde(rename_all = "camelCase")]
UserVerification {
transaction_id: u32,
display_hint: String,
username: Option<String>,
},
UserVerification(UserVerificationRequest),
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostResponseMessage {
request_id: u32,
response: HostResponse,
sequence_number: u32,
#[serde(flatten)]
response: Result<HostResponse, String>,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "request", content = "value", rename_all = "camelCase")]
#[serde(tag = "response", content = "value", rename_all = "camelCase")]
enum HostResponse {
UserVerification { user_verified: bool },
UserVerification(UserVerificationResponse),
}
#[derive(Serialize, Deserialize)]
@@ -290,6 +318,7 @@ enum HostResponse {
enum SerializedMessage {
Command(CommandMessage),
HostRequest(HostRequestMessage),
#[serde(rename_all = "camelCase")]
Message {
sequence_number: u32,
value: Result<serde_json::Value, BitwardenError>,
@@ -312,22 +341,19 @@ impl MacOSProviderClient {
}
#[allow(clippy::unwrap_used)]
fn send_message(
&self,
message: impl Serialize + DeserializeOwned,
callback: Option<Box<dyn Callback>>,
) {
fn send_message(&self, message: ExtensionRequest, callback: Option<Box<dyn Callback>>) {
let sequence_number = if let Some(callback) = callback {
self.add_callback(callback)
} else {
NO_CALLBACK_INDICATOR
};
let message = serde_json::to_string(&SerializedMessage::Message {
let message = serde_json::to_string(&ExtensionRequestMessage {
sequence_number,
value: Ok(serde_json::to_value(message).unwrap()),
value: message,
})
.expect("Can't serialize message");
tracing::debug!(%message, "Sending message to host");
if let Err(e) = self.to_server_send.blocking_send(message) {
// Make sure we remove the callback from the queue if we can't send the message
@@ -349,36 +375,44 @@ impl MacOSProviderClient {
/// Handles requests from the host to the provider.
async fn handle_host_request(to_server_send: &Sender<String>, message: HostRequestMessage) {
let request_id = message.request_id;
let sequence_number = message.sequence_number;
let response = match message.request {
uv_request @ HostRequest::UserVerification { .. } => {
tracing::debug!("Received UV request: {uv_request:?}");
HostResponse::UserVerification {
Ok(HostResponse::UserVerification(UserVerificationResponse {
user_verified: true,
}
}))
}
};
let message = serde_json::to_string(&HostResponseMessage {
request_id,
sequence_number,
response,
})
.expect("Can't serialize message");
if let Err(e) = to_server_send.send(message).await {
error!(%request_id, "Could not send response back to host: {e}");
error!(%sequence_number, "Could not send response back to host: {e}");
}
}
#[cfg(test)]
mod tests {
use crate::HostRequest;
use serde_json::Value;
use crate::{
assertion::PasskeyAssertionRequest,
registration::PasskeyRegistrationRequest,
user_verification::{UserVerificationRequest, UserVerificationResponse},
ExtensionRequest, ExtensionRequestMessage, HostRequest, HostResponse, HostResponseMessage,
Position,
};
use super::{HostRequestMessage, SerializedMessage};
#[test]
fn test_deserialize_host_request() {
let json = r#"{
"requestId": 1,
"sequenceNumber": 1,
"request": "userVerification",
"params": {
"transactionId": 0,
@@ -391,12 +425,74 @@ mod tests {
assert!(matches!(
message,
SerializedMessage::HostRequest(HostRequestMessage {
request_id: 1,
request: HostRequest::UserVerification {
sequence_number: 1,
request: HostRequest::UserVerification(UserVerificationRequest {
transaction_id: 0,
..
},
}),
}),
));
}
#[test]
fn test_serialize_host_response() {
let message = HostResponseMessage {
sequence_number: 7,
response: Ok(HostResponse::UserVerification(UserVerificationResponse {
user_verified: true,
})),
};
let json = serde_json::to_string(&message).unwrap();
let value: Value = serde_json::from_str(&json).unwrap();
assert_eq!(value["sequenceNumber"], 7);
assert_eq!(value["Ok"]["response"], "userVerification");
assert_eq!(value["Ok"]["value"]["userVerified"], true);
}
#[test]
fn test_serialize_extension_request() {
let message = ExtensionRequestMessage {
sequence_number: 42,
value: ExtensionRequest::PasskeyAssertion(PasskeyAssertionRequest {
rp_id: "example.com".to_string(),
client_data_hash: vec![1; 32],
user_verification: crate::UserVerification::Preferred,
allowed_credentials: vec![vec![4; 8]],
window_xy: Position { x: 100, y: 200 },
}),
};
let json = serde_json::to_string(&message).unwrap();
let value: Value = serde_json::from_str(&json).unwrap();
assert_eq!(value["sequenceNumber"], 42);
assert_eq!(value["request"], "passkeyAssertion");
let request: PasskeyAssertionRequest =
serde_json::from_value(value.as_object().unwrap().get("params").unwrap().clone())
.unwrap();
assert_eq!(request.rp_id, "example.com");
}
#[test]
fn test_deserialize_extension_response() {
let json = r#"{
"sequenceNumber": 1,
"value": {
"Ok": {
"rpId": "webauthn.io",
"clientDataHash": [156, 40, 76, 228, 28, 215, 79, 194, 237, 160, 250, 176, 57, 185, 247, 83, 247, 175, 218, 126, 161, 115, 202, 31, 71, 77, 49, 113, 197, 203, 88, 90],
"credentialId": [68, 161, 162, 129, 42, 70, 71, 239, 163, 98, 224, 14, 37, 190, 19, 70],
"attestationObject": [163, 99, 102, 109, 116, 100, 110, 111, 110, 101, 103, 97, 116, 116, 83, 116, 109, 116, 160, 104, 97, 117, 116, 104, 68, 97, 116, 97, 88, 148, 116, 166, 234, 146, 19, 201, 156, 47, 116, 178, 36, 146, 179, 32, 207, 64, 38, 42, 148, 193, 169, 80, 160, 57, 127, 41, 37, 11, 96, 132, 30, 240, 93, 0, 0, 0, 0, 213, 72, 130, 110, 121, 180, 219, 64, 163, 216, 17, 17, 111, 126, 131, 73, 0, 16, 68, 161, 162, 129, 42, 70, 71, 239, 163, 98, 224, 14, 37, 190, 19, 70, 165, 1, 2, 3, 38, 32, 1, 33, 88, 32, 204, 69, 19, 156, 78, 44, 190, 84, 242, 39, 36, 208, 150, 253, 237, 217, 249, 181, 225, 233, 218, 51, 252, 30, 63, 228, 232, 116, 70, 69, 107, 137, 34, 88, 32, 102, 120, 26, 113, 188, 129, 247, 29, 166, 195, 112, 151, 177, 248, 83, 120, 132, 188, 128, 160, 113, 89, 2, 141, 8, 190, 110, 6, 220, 5, 181, 96]
}
}
}"#;
let message = serde_json::from_str::<SerializedMessage>(&json).unwrap();
if let SerializedMessage::Message {
sequence_number: 1,
value: Ok(value),
} = message
{
assert_eq!(value["rpId"], "webauthn.io");
} else {
panic!("Does not match");
}
}
}

View File

@@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct UserVerificationRequest {
pub(crate) transaction_id: u32,
pub(crate) display_hint: String,
pub(crate) username: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct UserVerificationResponse {
pub(crate) user_verified: bool,
}

View File

@@ -21,8 +21,6 @@ export declare namespace autofill {
/** Prompt a user for user verification using OS APIs. */
verifyUser(request: UserVerificationRequest): Promise<UserVerificationResponse>
}
export type HostResponse =
| { type: 'UserVerification', field0: UserVerificationResponse }
export interface NativeStatus {
key: string
value: string

View File

@@ -645,7 +645,7 @@ pub mod autofill {
bindgen_prelude::FnArgs,
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tracing::error;
@@ -692,12 +692,29 @@ pub mod autofill {
pub is_focused: bool,
pub handle: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExtensionRequestMessage {
pub sequence_number: u32,
#[serde(flatten)]
pub request: ExtensionRequest,
}
#[derive(Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct PasskeyMessage<T: Serialize + DeserializeOwned> {
pub sequence_number: u32,
pub value: Result<T, BitwardenError>,
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
pub enum ExtensionRequest {
NativeStatus(NativeStatus),
PasskeyAssertion(PasskeyAssertionRequest),
PasskeyAssertionWithoutUserInterface(PasskeyAssertionWithoutUserInterfaceRequest),
PasskeyRegistration(PasskeyRegistrationRequest),
WindowHandle(WindowHandleQueryRequest),
}
#[derive(Serialize)]
#[serde(bound = "T: Serialize", rename_all = "camelCase")]
struct ExtensionResponse<T> {
sequence_number: u32,
value: Result<T, BitwardenError>,
}
#[derive(Serialize, Deserialize)]
@@ -709,11 +726,13 @@ pub mod autofill {
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostRequestMessage {
request_id: u32,
sequence_number: u32,
#[serde(flatten)]
request: HostRequest,
}
// TODO: HOST RESPONSE MESSAGE
#[napi(object)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -814,17 +833,35 @@ pub mod autofill {
pub user_verified: bool,
}
#[napi]
pub enum HostResponse {
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostResponseMessage {
sequence_number: u32,
#[serde(flatten)]
response: Result<HostResponse, String>,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "response", content = "value", rename_all = "camelCase")]
enum HostResponse {
UserVerification(UserVerificationResponse),
}
#[derive(Deserialize)]
#[serde(untagged)]
enum IpcMessage {
HostResponse(HostResponseMessage),
ExtensionRequest(ExtensionRequestMessage),
}
type HostCallbacks = Arc<Mutex<HashMap<u32, oneshot::Sender<Result<HostResponse, String>>>>>;
#[napi]
pub struct AutofillIpcServer {
server: desktop_core::ipc::server::Server,
// We need to keep track of the callbacks so we can call them when we receive a response
host_callbacks_counter: AtomicU32,
host_callbacks: Arc<Mutex<HashMap<u32, oneshot::Sender<Result<HostResponse, String>>>>>,
host_callbacks: HostCallbacks,
}
// FIXME: Remove unwraps! They panic and terminate the whole application.
@@ -872,6 +909,8 @@ pub mod autofill {
>,
) -> napi::Result<Self> {
let (send, mut recv) = tokio::sync::mpsc::channel::<Message>(32);
let host_callbacks: HostCallbacks = Arc::new(Mutex::new(HashMap::new()));
let host_callbacks2 = host_callbacks.clone();
tokio::spawn(async move {
while let Some(Message {
client_id,
@@ -888,94 +927,75 @@ pub mod autofill {
continue;
};
match serde_json::from_str::<PasskeyMessage<WindowHandleQueryRequest>>(
&message,
) {
Ok(msg) => {
let value = msg
.value
.map(|value| (client_id, msg.sequence_number, value).into())
.map_err(|e| napi::Error::from_reason(format!("{e:?}")));
window_handle_query_callback
.call(value, ThreadsafeFunctionCallMode::NonBlocking);
continue;
}
Err(e) => {
tracing::warn!(error = %e, "Could not deserialize request as WindowHandleQueryRequest. Trying other types...");
}
}
match serde_json::from_str::<PasskeyMessage<PasskeyAssertionRequest>>(
&message,
) {
Ok(msg) => {
let value = msg
.value
.map(|value| (client_id, msg.sequence_number, value).into())
.map_err(|e| napi::Error::from_reason(format!("{e:?}")));
assertion_callback
.call(value, ThreadsafeFunctionCallMode::NonBlocking);
continue;
}
Err(e) => {
error!(error = %e, "Error deserializing as PasskeyAssertionRequest");
}
}
match serde_json::from_str::<
PasskeyMessage<PasskeyAssertionWithoutUserInterfaceRequest>,
>(&message)
{
Ok(msg) => {
let value = msg
.value
.map(|value| (client_id, msg.sequence_number, value).into())
.map_err(|e| napi::Error::from_reason(format!("{e:?}")));
assertion_without_user_interface_callback
.call(value, ThreadsafeFunctionCallMode::NonBlocking);
continue;
}
Err(e) => {
error!(error = %e, "Error deserializing as PasskeyAssertionWithoutUserInterfaceRequest");
}
}
match serde_json::from_str::<PasskeyMessage<PasskeyRegistrationRequest>>(
&message,
) {
Ok(msg) => {
let value = msg
.value
.map(|value| (client_id, msg.sequence_number, value).into())
.map_err(|e| napi::Error::from_reason(format!("{e:?}")));
registration_callback
.call(value, ThreadsafeFunctionCallMode::NonBlocking);
continue;
}
Err(e) => {
error!(error = %e, "Error deserializing PasskeyRegistrationRequest");
}
}
match serde_json::from_str::<PasskeyMessage<NativeStatus>>(&message) {
Ok(msg) => {
let value = msg
.value
.map(|value| (client_id, msg.sequence_number, value))
.map_err(|e| napi::Error::from_reason(format!("{e:?}")));
native_status_callback
.call(value, ThreadsafeFunctionCallMode::NonBlocking);
continue;
}
let msg = match serde_json::from_str::<IpcMessage>(&message) {
Ok(msg) => msg,
Err(error) => {
error!(%error, "Unable to deserialize as native status.");
error!(
%error,
%message,
"Received an unknown message from extension"
);
continue;
}
}
};
error!(message, "Received an unknown message");
match msg {
IpcMessage::HostResponse(msg) => {
if let Some(tx) = host_callbacks2
.lock()
.expect("not poisoned")
.remove(&msg.sequence_number)
{
if let Err(err) = tx.send(msg.response) {
tracing::error!("");
}
}
}
IpcMessage::ExtensionRequest(msg) => match msg.request {
ExtensionRequest::PasskeyAssertion(assertion_request) => {
let params =
(client_id, msg.sequence_number, assertion_request);
assertion_callback.call(
Ok(params.into()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
ExtensionRequest::PasskeyAssertionWithoutUserInterface(
assertion_request,
) => {
let params =
(client_id, msg.sequence_number, assertion_request);
assertion_without_user_interface_callback.call(
Ok(params.into()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
ExtensionRequest::PasskeyRegistration(assertion_request) => {
let params =
(client_id, msg.sequence_number, assertion_request);
registration_callback.call(
Ok(params.into()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
ExtensionRequest::WindowHandle(window_handle_request) => {
let params =
(client_id, msg.sequence_number, window_handle_request);
window_handle_query_callback.call(
Ok(params.into()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
ExtensionRequest::NativeStatus(status_request) => {
let params =
(client_id, msg.sequence_number, status_request);
native_status_callback.call(
Ok(params.into()),
ThreadsafeFunctionCallMode::NonBlocking,
);
}
},
}
}
}
}
@@ -993,7 +1013,7 @@ pub mod autofill {
server,
// Start at 1 since 0 is reserved for "no callback" scenarios
host_callbacks_counter: AtomicU32::new(NO_CALLBACK_INDICATOR + 1),
host_callbacks: Arc::new(Mutex::new(HashMap::new())),
host_callbacks,
})
}
@@ -1017,7 +1037,7 @@ pub mod autofill {
sequence_number: u32,
response: PasskeyRegistrationResponse,
) -> napi::Result<u32> {
let message = PasskeyMessage {
let message = ExtensionResponse {
sequence_number,
value: Ok(response),
};
@@ -1031,7 +1051,7 @@ pub mod autofill {
sequence_number: u32,
response: PasskeyAssertionResponse,
) -> napi::Result<u32> {
let message = PasskeyMessage {
let message = ExtensionResponse {
sequence_number,
value: Ok(response),
};
@@ -1045,7 +1065,7 @@ pub mod autofill {
sequence_number: u32,
response: WindowHandleQueryResponse,
) -> napi::Result<u32> {
let message = PasskeyMessage {
let message = ExtensionResponse {
sequence_number,
value: Ok(response),
};
@@ -1059,7 +1079,7 @@ pub mod autofill {
sequence_number: u32,
error: String,
) -> napi::Result<u32> {
let message: PasskeyMessage<()> = PasskeyMessage {
let message: ExtensionResponse<()> = ExtensionResponse {
sequence_number,
value: Err(BitwardenError::Internal(error)),
};
@@ -1072,18 +1092,18 @@ pub mod autofill {
&self,
request: UserVerificationRequest,
) -> napi::Result<UserVerificationResponse> {
let request_id = self
let sequence_number = self
.host_callbacks_counter
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let (tx, rx) = oneshot::channel();
let command = HostRequestMessage {
request_id,
sequence_number,
request: HostRequest::UserVerification(request),
};
self.host_callbacks
.lock()
.expect("not poisoned")
.insert(request_id, tx);
.insert(sequence_number, tx);
let json = serde_json::to_string(&command).expect("serde to serialize");
tracing::debug!(json, "Sending verify user message");
self.send(0, json)?;
@@ -1104,6 +1124,7 @@ pub mod autofill {
// TODO: Add a way to send a message to a specific client?
fn send(&self, _client_id: u32, message: String) -> napi::Result<u32> {
tracing::debug!(%message, "Sending message to extension");
self.server
.send(message)
.map_err(|e| {