1
0
mirror of https://github.com/bitwarden/browser synced 2026-02-06 19:53:59 +00:00

Wire up bidirectional communication on autofill IPC server

This commit is contained in:
Isaiah Inuwa
2026-01-09 07:30:10 -06:00
parent 5de23bd24f
commit cb38151cac
2 changed files with 135 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ use std::{
use futures::FutureExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::sync::mpsc::Sender;
use tracing::{error, info};
use tracing_subscriber::{
filter::{EnvFilter, LevelFilter},
@@ -114,6 +115,8 @@ impl MacOSProviderClient {
let (from_server_send, mut from_server_recv) = tokio::sync::mpsc::channel(32);
let (to_server_send, to_server_recv) = tokio::sync::mpsc::channel(32);
let (host_request_handler_tx, mut host_request_handler_rx) = tokio::sync::mpsc::channel(32);
let to_server_send2 = to_server_send.clone();
let client = MacOSProviderClient {
to_server_send,
@@ -139,8 +142,15 @@ impl MacOSProviderClient {
.map(|r| r.map_err(|e| e.to_string())),
);
rt.spawn(async move {
while let Some(message) = host_request_handler_rx.recv().await {
handle_host_request(&to_server_send2, message).await;
}
});
rt.block_on(async move {
while let Some(message) = from_server_recv.recv().await {
tracing::debug!(?message, "Received message");
match serde_json::from_str::<SerializedMessage>(&message) {
Ok(SerializedMessage::Command(CommandMessage::Connected)) => {
info!("Connected to server");
@@ -150,6 +160,15 @@ impl MacOSProviderClient {
info!("Disconnected from server");
connection_status.store(false, std::sync::atomic::Ordering::Relaxed);
}
Ok(SerializedMessage::HostRequest(message)) => {
let request_id = message.request_id;
tracing::debug!(%request_id, "Received request");
if let Err(err) = host_request_handler_tx.send(message).await {
tracing::error!(
"Failed to pass message to host request handler: {err}"
);
}
}
Ok(SerializedMessage::Message {
sequence_number,
value,
@@ -227,17 +246,50 @@ impl MacOSProviderClient {
}
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "command", rename_all = "camelCase")]
enum CommandMessage {
Connected,
Disconnected,
}
/// Requests from the host to the provider.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostRequestMessage {
request_id: u32,
#[serde(flatten)]
request: HostRequest,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
enum HostRequest {
#[serde(rename_all = "camelCase")]
UserVerification {
transaction_id: u32,
display_hint: String,
username: Option<String>,
},
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostResponseMessage {
request_id: u32,
response: HostResponse,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "request", content = "value", rename_all = "camelCase")]
enum HostResponse {
UserVerification { user_verified: bool },
}
#[derive(Serialize, Deserialize)]
#[serde(untagged, rename_all = "camelCase")]
enum SerializedMessage {
Command(CommandMessage),
HostRequest(HostRequestMessage),
Message {
sequence_number: u32,
value: Result<serde_json::Value, BitwardenError>,
@@ -294,3 +346,57 @@ impl MacOSProviderClient {
}
}
}
/// Handles requests from the host to the provider.
async fn handle_host_request(to_server_send: &Sender<String>, message: HostRequestMessage) {
let request_id = message.request_id;
let response = match message.request {
uv_request @ HostRequest::UserVerification { .. } => {
tracing::debug!("Received UV request: {uv_request:?}");
HostResponse::UserVerification {
user_verified: true,
}
}
};
let message = serde_json::to_string(&HostResponseMessage {
request_id,
response,
})
.expect("Can't serialize message");
if let Err(e) = to_server_send.send(message).await {
error!(%request_id, "Could not send response back to host: {e}");
}
}
#[cfg(test)]
mod tests {
use crate::HostRequest;
use super::{HostRequestMessage, SerializedMessage};
#[test]
fn test_deserialize_host_request() {
let json = r#"{
"requestId": 1,
"request": "userVerification",
"params": {
"transactionId": 0,
"displayHint": "Verify it's you to overwrite a credential",
"username": "bw-ii-plugin1"
}
}"#;
let value = serde_json::from_str::<serde_json::Value>(&json).unwrap();
let message = serde_json::from_value::<SerializedMessage>(value).unwrap();
assert!(matches!(
message,
SerializedMessage::HostRequest(HostRequestMessage {
request_id: 1,
request: HostRequest::UserVerification {
transaction_id: 0,
..
},
}),
));
}
}

View File

@@ -641,8 +641,13 @@ pub mod autofill {
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::sync::oneshot;
use tracing::error;
/// In our callback management, 0 is a reserved sequence number indicating that a message does not
/// have a callback.
const NO_CALLBACK_INDICATOR: u32 = 0;
#[napi]
pub async fn run_command(value: String) -> napi::Result<String> {
desktop_core::autofill::run_command(value)
@@ -690,6 +695,20 @@ pub mod autofill {
pub value: Result<T, BitwardenError>,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
pub enum HostRequest {
UserVerification(UserVerificationRequest),
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HostRequestMessage {
request_id: u32,
#[serde(flatten)]
request: HostRequest,
}
#[napi(object)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -777,6 +796,9 @@ pub mod autofill {
#[napi]
pub struct AutofillIpcServer {
server: desktop_core::ipc::server::Server,
// We need to keep track of the callbacks so we can call them when we receive a response
host_callbacks_counter: AtomicU32,
host_callbacks: Arc<Mutex<HashMap<u32, oneshot::Sender<Result<HostResponse, String>>>>>,
}
// FIXME: Remove unwraps! They panic and terminate the whole application.
@@ -941,7 +963,12 @@ pub mod autofill {
))
})?;
Ok(AutofillIpcServer { server })
Ok(AutofillIpcServer {
server,
// Start at 1 since 0 is reserved for "no callback" scenarios
host_callbacks_counter: AtomicU32::new(NO_CALLBACK_INDICATOR + 1),
host_callbacks: Arc::new(Mutex::new(HashMap::new())),
})
}
/// Return the path to the IPC server.