mirror of
https://github.com/bitwarden/browser
synced 2026-02-06 19:53:59 +00:00
Wire up bidirectional communication on autofill IPC server
This commit is contained in:
@@ -9,6 +9,7 @@ use std::{
|
||||
|
||||
use futures::FutureExt;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::{
|
||||
filter::{EnvFilter, LevelFilter},
|
||||
@@ -114,6 +115,8 @@ impl MacOSProviderClient {
|
||||
|
||||
let (from_server_send, mut from_server_recv) = tokio::sync::mpsc::channel(32);
|
||||
let (to_server_send, to_server_recv) = tokio::sync::mpsc::channel(32);
|
||||
let (host_request_handler_tx, mut host_request_handler_rx) = tokio::sync::mpsc::channel(32);
|
||||
let to_server_send2 = to_server_send.clone();
|
||||
|
||||
let client = MacOSProviderClient {
|
||||
to_server_send,
|
||||
@@ -139,8 +142,15 @@ impl MacOSProviderClient {
|
||||
.map(|r| r.map_err(|e| e.to_string())),
|
||||
);
|
||||
|
||||
rt.spawn(async move {
|
||||
while let Some(message) = host_request_handler_rx.recv().await {
|
||||
handle_host_request(&to_server_send2, message).await;
|
||||
}
|
||||
});
|
||||
|
||||
rt.block_on(async move {
|
||||
while let Some(message) = from_server_recv.recv().await {
|
||||
tracing::debug!(?message, "Received message");
|
||||
match serde_json::from_str::<SerializedMessage>(&message) {
|
||||
Ok(SerializedMessage::Command(CommandMessage::Connected)) => {
|
||||
info!("Connected to server");
|
||||
@@ -150,6 +160,15 @@ impl MacOSProviderClient {
|
||||
info!("Disconnected from server");
|
||||
connection_status.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
Ok(SerializedMessage::HostRequest(message)) => {
|
||||
let request_id = message.request_id;
|
||||
tracing::debug!(%request_id, "Received request");
|
||||
if let Err(err) = host_request_handler_tx.send(message).await {
|
||||
tracing::error!(
|
||||
"Failed to pass message to host request handler: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(SerializedMessage::Message {
|
||||
sequence_number,
|
||||
value,
|
||||
@@ -227,17 +246,50 @@ impl MacOSProviderClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "command", rename_all = "camelCase")]
|
||||
enum CommandMessage {
|
||||
Connected,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
/// Requests from the host to the provider.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct HostRequestMessage {
|
||||
request_id: u32,
|
||||
#[serde(flatten)]
|
||||
request: HostRequest,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
|
||||
enum HostRequest {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
UserVerification {
|
||||
transaction_id: u32,
|
||||
display_hint: String,
|
||||
username: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct HostResponseMessage {
|
||||
request_id: u32,
|
||||
response: HostResponse,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "request", content = "value", rename_all = "camelCase")]
|
||||
enum HostResponse {
|
||||
UserVerification { user_verified: bool },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(untagged, rename_all = "camelCase")]
|
||||
enum SerializedMessage {
|
||||
Command(CommandMessage),
|
||||
HostRequest(HostRequestMessage),
|
||||
Message {
|
||||
sequence_number: u32,
|
||||
value: Result<serde_json::Value, BitwardenError>,
|
||||
@@ -294,3 +346,57 @@ impl MacOSProviderClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles requests from the host to the provider.
|
||||
async fn handle_host_request(to_server_send: &Sender<String>, message: HostRequestMessage) {
|
||||
let request_id = message.request_id;
|
||||
let response = match message.request {
|
||||
uv_request @ HostRequest::UserVerification { .. } => {
|
||||
tracing::debug!("Received UV request: {uv_request:?}");
|
||||
HostResponse::UserVerification {
|
||||
user_verified: true,
|
||||
}
|
||||
}
|
||||
};
|
||||
let message = serde_json::to_string(&HostResponseMessage {
|
||||
request_id,
|
||||
response,
|
||||
})
|
||||
.expect("Can't serialize message");
|
||||
|
||||
if let Err(e) = to_server_send.send(message).await {
|
||||
error!(%request_id, "Could not send response back to host: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::HostRequest;
|
||||
|
||||
use super::{HostRequestMessage, SerializedMessage};
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_host_request() {
|
||||
let json = r#"{
|
||||
"requestId": 1,
|
||||
"request": "userVerification",
|
||||
"params": {
|
||||
"transactionId": 0,
|
||||
"displayHint": "Verify it's you to overwrite a credential",
|
||||
"username": "bw-ii-plugin1"
|
||||
}
|
||||
}"#;
|
||||
let value = serde_json::from_str::<serde_json::Value>(&json).unwrap();
|
||||
let message = serde_json::from_value::<SerializedMessage>(value).unwrap();
|
||||
assert!(matches!(
|
||||
message,
|
||||
SerializedMessage::HostRequest(HostRequestMessage {
|
||||
request_id: 1,
|
||||
request: HostRequest::UserVerification {
|
||||
transaction_id: 0,
|
||||
..
|
||||
},
|
||||
}),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,8 +641,13 @@ pub mod autofill {
|
||||
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::error;
|
||||
|
||||
/// In our callback management, 0 is a reserved sequence number indicating that a message does not
|
||||
/// have a callback.
|
||||
const NO_CALLBACK_INDICATOR: u32 = 0;
|
||||
|
||||
#[napi]
|
||||
pub async fn run_command(value: String) -> napi::Result<String> {
|
||||
desktop_core::autofill::run_command(value)
|
||||
@@ -690,6 +695,20 @@ pub mod autofill {
|
||||
pub value: Result<T, BitwardenError>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "request", content = "params", rename_all = "camelCase")]
|
||||
pub enum HostRequest {
|
||||
UserVerification(UserVerificationRequest),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct HostRequestMessage {
|
||||
request_id: u32,
|
||||
#[serde(flatten)]
|
||||
request: HostRequest,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -777,6 +796,9 @@ pub mod autofill {
|
||||
#[napi]
|
||||
pub struct AutofillIpcServer {
|
||||
server: desktop_core::ipc::server::Server,
|
||||
// We need to keep track of the callbacks so we can call them when we receive a response
|
||||
host_callbacks_counter: AtomicU32,
|
||||
host_callbacks: Arc<Mutex<HashMap<u32, oneshot::Sender<Result<HostResponse, String>>>>>,
|
||||
}
|
||||
|
||||
// FIXME: Remove unwraps! They panic and terminate the whole application.
|
||||
@@ -941,7 +963,12 @@ pub mod autofill {
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(AutofillIpcServer { server })
|
||||
Ok(AutofillIpcServer {
|
||||
server,
|
||||
// Start at 1 since 0 is reserved for "no callback" scenarios
|
||||
host_callbacks_counter: AtomicU32::new(NO_CALLBACK_INDICATOR + 1),
|
||||
host_callbacks: Arc::new(Mutex::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the path to the IPC server.
|
||||
|
||||
Reference in New Issue
Block a user