diff --git a/akd/crates/akd_storage/src/ms_sql/mod.rs b/akd/crates/akd_storage/src/ms_sql/mod.rs
index 4c8a2e9b26..f4e83df74d 100644
--- a/akd/crates/akd_storage/src/ms_sql/mod.rs
+++ b/akd/crates/akd_storage/src/ms_sql/mod.rs
@@ -25,16 +25,17 @@ use tables::{
     temp_table::TempTable,
     values,
 };
+use uuid::Uuid;
 
 use crate::{
     ms_sql::tables::{
-        akd_storable_for_ms_sql::QueryStatement,
         publish_queue::{
-            bulk_delete_rows, bulk_delete_statement, enqueue_statement, peek_statement,
+            bulk_delete_rows, bulk_delete_statement, enqueue_statement, peek_no_limit_statement,
+            peek_statement,
         },
         vrf_key,
     },
-    publish_queue::{PublishQueue, PublishQueueError, PublishQueueItem},
+    publish_queue::{PublishQueue, PublishQueueError},
     vrf_key_config::VrfKeyConfig,
     vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData},
 };
@@ -753,30 +754,36 @@ impl Database for MsSql {
 
 #[async_trait]
 impl PublishQueue for MsSql {
-    #[instrument(skip(self, raw_label, raw_value), level = "debug")]
-    async fn enqueue(
-        &self,
-        raw_label: Vec<u8>,
-        raw_value: Vec<u8>,
-    ) -> Result<(), PublishQueueError> {
+    #[instrument(skip(self, label, value), level = "debug")]
+    async fn enqueue(&self, label: AkdLabel, value: AkdValue) -> Result<(), PublishQueueError> {
         debug!("Enqueuing item to publish queue");
 
-        let statement = enqueue_statement(raw_label, raw_value);
+        let statement = enqueue_statement(label, value);
         self.execute_statement(&statement)
             .await
             .map_err(|_| PublishQueueError)
     }
 
     #[instrument(skip(self), level = "debug")]
-    async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
-        if limit <= 0 {
-            debug!("Peek called with non-positive limit, returning empty vector");
-            return Ok(vec![]);
-        }
+    async fn peek(
+        &self,
+        limit: Option<isize>,
+    ) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError> {
+        let statement = match limit {
+            Some(limit) if limit <= 0 => {
+                warn!("Peek called with non-positive limit, returning empty vector");
+                return Ok(vec![]);
+            }
+            Some(limit) => {
+                debug!(limit, "Peeking items from publish queue");
+                peek_statement(limit)
+            }
+            None => {
+                debug!("Peeking items from publish queue with no limit");
+                peek_no_limit_statement()
+            }
+        };
 
-        debug!(limit, "Peeking items from publish queue");
-
-        let statement = peek_statement(limit);
         let mut conn = self.get_connection().await.map_err(|_| {
             error!("Failed to get DB connection for peek");
             PublishQueueError
@@ -804,6 +811,7 @@ impl PublishQueue for MsSql {
                 queued_items.push(item);
             }
         }
+
         debug!(
             item_count = queued_items.len(),
             "Peeked items from publish queue"
diff --git a/akd/crates/akd_storage/src/ms_sql/tables/publish_queue.rs b/akd/crates/akd_storage/src/ms_sql/tables/publish_queue.rs
index 2c4aca21ab..8ec599bead 100644
--- a/akd/crates/akd_storage/src/ms_sql/tables/publish_queue.rs
+++ b/akd/crates/akd_storage/src/ms_sql/tables/publish_queue.rs
@@ -1,3 +1,4 @@
+use akd::{AkdLabel, AkdValue};
 use ms_database::{IntoRow, TokenRow};
 use tracing::{debug, error};
 use uuid::Uuid;
@@ -6,20 +7,17 @@ use crate::{
     ms_sql::{
         migrations::TABLE_PUBLISH_QUEUE,
         sql_params::SqlParams,
-        tables::{
-            akd_storable_for_ms_sql::{QueryStatement, Statement},
-            temp_table::TempTable,
-        },
+        tables::akd_storable_for_ms_sql::{QueryStatement, Statement},
     },
-    publish_queue::{PublishQueueError, PublishQueueItem},
+    publish_queue::PublishQueueError,
 };
 
-pub fn enqueue_statement(raw_label: Vec<u8>, raw_value: Vec<u8>) -> Statement {
+pub fn enqueue_statement(label: AkdLabel, value: AkdValue) -> Statement {
     debug!("Building enqueue_statement for publish queue");
     let mut params = SqlParams::new();
     params.add("id", Box::new(uuid::Uuid::now_v7()));
-    params.add("raw_label", Box::new(raw_label));
-    params.add("raw_value", Box::new(raw_value));
+    params.add("raw_label", Box::new(label.0));
+    params.add("raw_value", Box::new(value.0));
 
     let sql = format!(
         r#"
@@ -40,7 +38,9 @@ pub fn enqueue_statement(raw_label: Vec<u8>, raw_value: Vec<u8>) -> Statement {
     Statement::new(sql, params)
 }
 
-pub fn peek_statement(limit: isize) -> QueryStatement<PublishQueueItem, PublishQueueError> {
+pub fn peek_statement(
+    limit: isize,
+) -> QueryStatement<(Uuid, (AkdLabel, AkdValue)), PublishQueueError> {
     debug!("Building peek_statement for publish queue");
     let sql = format!(
         r#"
@@ -49,26 +49,42 @@ pub fn peek_statement(limit: isize) -> QueryStatement
+pub fn peek_no_limit_statement() -> QueryStatement<(Uuid, (AkdLabel, AkdValue)), PublishQueueError>
+{
+    debug!("Building peek_statement with no limit for publish queue");
+    let sql = format!(
+        r#"
+        SELECT id, raw_label, raw_value
+        FROM {}
+        ORDER BY id ASC"#,
+        TABLE_PUBLISH_QUEUE
+    );
+    QueryStatement::new(sql, SqlParams::new(), publish_queue_item_from_row)
+}
+
+fn publish_queue_item_from_row(
+    row: &ms_database::Row,
+) -> Result<(Uuid, (AkdLabel, AkdValue)), PublishQueueError> {
+    let id: uuid::Uuid = row.get("id").ok_or_else(|| {
+        error!("id is NULL or missing in publish queue row");
+        PublishQueueError
+    })?;
+    let raw_label: &[u8] = row.get("raw_label").ok_or_else(|| {
+        error!("raw_label is NULL or missing in publish queue row");
+        PublishQueueError
+    })?;
+    let raw_value: &[u8] = row.get("raw_value").ok_or_else(|| {
+        error!("raw_value is NULL or missing in publish queue row");
+        PublishQueueError
+    })?;
+
+    Ok((
+        id,
+        (AkdLabel(raw_label.to_vec()), AkdValue(raw_value.to_vec())),
+    ))
 }
 
 pub fn bulk_delete_rows(ids: &'_ [Uuid]) -> Result<Vec<TokenRow<'_>>, PublishQueueError> {
diff --git a/akd/crates/akd_storage/src/publish_queue.rs b/akd/crates/akd_storage/src/publish_queue.rs
index 4b4fcd6fcf..0837342581 100644
--- a/akd/crates/akd_storage/src/publish_queue.rs
+++ b/akd/crates/akd_storage/src/publish_queue.rs
@@ -1,31 +1,28 @@
+use akd::{AkdLabel, AkdValue};
 use async_trait::async_trait;
 use thiserror::Error;
+use uuid::Uuid;
 
 use crate::{
-    db_config::DatabaseType,
-    ms_sql::MsSql,
-    publish_queue_config::{PublishQueueConfig, PublishQueueProvider},
-    AkdDatabase,
+    db_config::DatabaseType, ms_sql::MsSql, publish_queue_config::PublishQueueConfig, AkdDatabase,
 };
 
-pub(crate) struct PublishQueueItem {
-    pub id: uuid::Uuid,
-    pub raw_label: Vec<u8>,
-    pub raw_value: Vec<u8>,
-}
-
 #[derive(Debug, Error)]
 #[error("Publish queue error")]
 pub struct PublishQueueError;
 
 #[async_trait]
 pub trait PublishQueue {
+    // TODO: should this method ensure that a given label is not already present in the queue? How to handle that?
     async fn enqueue(
         &self,
-        raw_label: Vec<u8>,
-        raw_value: Vec<u8>,
+        raw_label: AkdLabel,
+        raw_value: AkdValue,
     ) -> Result<(), PublishQueueError>;
-    async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError>;
+    async fn peek(
+        &self,
+        limit: Option<isize>,
+    ) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError>;
 
     async fn remove(&self, ids: Vec<Uuid>) -> Result<(), PublishQueueError>;
 }
@@ -36,8 +33,8 @@ pub enum PublishQueueType {
 
 impl PublishQueueType {
     pub fn new(config: &PublishQueueConfig, db: &AkdDatabase) -> PublishQueueType {
-        match &config.provider {
-            PublishQueueProvider::DbBacked => db.into(),
+        match config {
+            PublishQueueConfig::DbBacked => db.into(),
         }
     }
 }
@@ -54,32 +51,26 @@
 impl PublishQueue for PublishQueueType {
     async fn enqueue(
         &self,
-        _raw_label: Vec<u8>,
-        _raw_value: Vec<u8>,
+        raw_label: AkdLabel,
+        raw_value: AkdValue,
     ) -> Result<(), PublishQueueError> {
         match self {
-            PublishQueueType::MsSql(_ms_sql) => {
-                // Implement enqueue logic for MsSql
-                Ok(())
-            }
+            PublishQueueType::MsSql(ms_sql) => ms_sql.enqueue(raw_label, raw_value).await,
         }
     }
 
-    async fn peek(&self, _max: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
+    async fn peek(
+        &self,
+        limit: Option<isize>,
+    ) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError> {
         match self {
-            PublishQueueType::MsSql(_ms_sql) => {
-                // Implement peek logic for MsSql
-                Ok(vec![])
-            }
+            PublishQueueType::MsSql(ms_sql) => ms_sql.peek(limit).await,
         }
     }
 
-    async fn remove(&self, _ids: Vec<Uuid>) -> Result<(), PublishQueueError> {
+    async fn remove(&self, ids: Vec<Uuid>) -> Result<(), PublishQueueError> {
         match self {
-            PublishQueueType::MsSql(_ms_sql) => {
-                // Implement remove logic for MsSql
-                Ok(())
-            }
+            PublishQueueType::MsSql(ms_sql) => ms_sql.remove(ids).await,
         }
     }
 }
diff --git a/akd/crates/akd_storage/src/publish_queue_config.rs b/akd/crates/akd_storage/src/publish_queue_config.rs
index e6499fb25e..3867db04af 100644
--- a/akd/crates/akd_storage/src/publish_queue_config.rs
+++ b/akd/crates/akd_storage/src/publish_queue_config.rs
@@ -2,19 +2,8 @@
 use serde::Deserialize;
 
 #[derive(Debug, Clone, Deserialize)]
 #[serde(tag = "type")]
-pub struct PublishQueueConfig {
-    pub provider: PublishQueueProvider,
-    #[serde(default = "default_publish_limit")]
-    pub epoch_update_limit: Option<isize>,
-}
-
-fn default_publish_limit() -> Option<isize> {
-    None
-}
-
-#[derive(Debug, Clone, Deserialize)]
-#[serde(tag = "type")]
-pub enum PublishQueueProvider {
+pub enum PublishQueueConfig {
+    /// Database-backed publish queue
     DbBacked,
 }
diff --git a/akd/crates/akd_test_utility/src/main.rs b/akd/crates/akd_test_utility/src/main.rs
index 99a86cf474..140a781810 100644
--- a/akd/crates/akd_test_utility/src/main.rs
+++ b/akd/crates/akd_test_utility/src/main.rs
@@ -169,10 +169,7 @@ async fn main() -> Result<()> {
         vrf_key_config: akd_storage::vrf_key_config::VrfKeyConfig::B64EncodedSymmetricKey {
             key: "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=".to_string(),
         },
-        publish_queue_config: akd_storage::publish_queue_config::PublishQueueConfig {
-            provider: akd_storage::publish_queue_config::PublishQueueProvider::DbBacked,
-            epoch_update_limit: None,
-        },
+        publish_queue_config: akd_storage::publish_queue_config::PublishQueueConfig::DbBacked,
     };
     let (mut directory, db, _) = config
         .initialize_directory::<BitwardenV1Configuration>()
diff --git a/akd/crates/publisher/src/config.rs b/akd/crates/publisher/src/config.rs
index 1fcdd812b6..d65fcc1d44 100644
--- a/akd/crates/publisher/src/config.rs
+++ b/akd/crates/publisher/src/config.rs
@@ -9,20 +9,37 @@
 const DEFAULT_EPOCH_DURATION_MS: u64 = 30000; // 30 seconds
 
 pub struct ApplicationConfig {
     pub storage: AkdStorageConfig,
     pub publisher: PublisherConfig,
+    /// The unique Bitwarden installation ID using this AKD publisher instance.
+    /// This value is used to namespace AKD data to a given installation.
     pub installation_id: Uuid,
+    /// The address the web server will bind to. Defaults to "127.0.0.1:3000".
+    #[serde(default = "default_web_server_bind_address")]
+    web_server_bind_address: String,
     // web_server: WebServerConfig,
 }
 
+fn default_web_server_bind_address() -> String {
+    "127.0.0.1:3000".to_string()
+}
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct PublisherConfig {
+    /// The duration of each publishing epoch in milliseconds. Defaults to 30 seconds.
     #[serde(default = "default_epoch_duration_ms")]
-    epoch_duration_ms: u64,
+    pub epoch_duration_ms: u64,
+    /// The limit to the number of AKD values to update in a single epoch. Defaults to no limit.
+    #[serde(default = "default_epoch_update_limit")]
+    pub epoch_update_limit: Option<isize>,
 }
 
 fn default_epoch_duration_ms() -> u64 {
     DEFAULT_EPOCH_DURATION_MS
 }
 
+fn default_epoch_update_limit() -> Option<isize> {
+    None
+}
+
 impl ApplicationConfig {
@@ -65,6 +82,14 @@ impl ApplicationConfig {
         self.publisher.validate()?;
         Ok(())
     }
+
+    /// Get the web server bind address as a SocketAddr
+    /// Panics if the address is invalid
+    pub fn socket_address(&self) -> std::net::SocketAddr {
+        self.web_server_bind_address
+            .parse()
+            .expect("Invalid web server bind address")
+    }
 }
 
 impl PublisherConfig {
diff --git a/akd/crates/publisher/src/lib.rs b/akd/crates/publisher/src/lib.rs
index 2ad4d41ca5..9ce54fbe85 100644
--- a/akd/crates/publisher/src/lib.rs
+++ b/akd/crates/publisher/src/lib.rs
@@ -1,7 +1,10 @@
-use anyhow::Result;
+use akd_storage::{PublishQueue, PublishQueueType};
+use anyhow::{Context, Result};
+use axum::Router;
 use bitwarden_akd_configuration::BitwardenV1Configuration;
-use tokio::sync::broadcast::Receiver;
-use tracing::{info, instrument};
+use common::BitAkdDirectory;
+use tokio::{net::TcpListener, sync::broadcast::Receiver};
+use tracing::{error, info, instrument, trace};
 
 mod config;
 mod routes;
@@ -15,28 +18,33 @@ pub struct AppHandles {
 
 #[instrument(skip_all, name = "publisher_start")]
 pub async fn start(config: ApplicationConfig, shutdown_rx: &Receiver<()>) -> Result<AppHandles> {
-    let (directory, db, publish_queue) = config
+    let (directory, _, publish_queue) = config
         .storage
         .initialize_directory::<BitwardenV1Configuration>()
         .await?;
 
     // Initialize write job
     let write_handle = {
-        let mut shutdown_rx = shutdown_rx.resubscribe();
+        let shutdown_rx = shutdown_rx.resubscribe();
+        let publish_queue = publish_queue.clone();
+        let config = config.clone();
+
         tokio::spawn(async move {
-            // wait until shutdown signal is received
-            shutdown_rx.recv().await.ok();
-            info!("Shutting down publisher write job");
+            if let Err(e) = start_publisher(directory, publish_queue, &config, shutdown_rx).await {
+                error!(err = %e, "Publisher write job failed");
+            }
         })
     };
 
     // Initialize web server
     let web_handle = {
-        let mut shutdown_rx = shutdown_rx.resubscribe();
+        let shutdown_rx = shutdown_rx.resubscribe();
+        let publish_queue = publish_queue.clone();
+
         tokio::spawn(async move {
-            // wait forever until shutdown signal is received
-            shutdown_rx.recv().await.ok();
-            info!("Shutting down publisher web server");
+            if let Err(e) = start_web(publish_queue, &config, shutdown_rx).await {
server failed"); + } }) }; @@ -45,3 +53,98 @@ pub async fn start(config: ApplicationConfig, shutdown_rx: &Receiver<()>) -> Res web_handle, }) } + +#[instrument(skip_all)] +async fn start_publisher( + directory: BitAkdDirectory, + publish_queue: PublishQueueType, + config: &ApplicationConfig, + mut shutdown_rx: Receiver<()>, +) -> Result<()> { + let mut next_epoch = tokio::time::Instant::now() + + std::time::Duration::from_millis(config.publisher.epoch_duration_ms as u64); + loop { + trace!("Processing publish queue for epoch"); + + // Pull items from publish queue + let (ids, items) = publish_queue + .peek(config.publisher.epoch_update_limit) + .await? + .into_iter() + .fold((vec![], vec![]), |mut acc, i| { + acc.0.push(i.0); + acc.1.push(i.1); + acc + }); + + let result: Result<()> = { + // Apply items to directory + directory + .publish(items) + .await + .context("AKD publish failed")?; + + // Remove processed items from publish queue + publish_queue + .remove(ids) + .await + .context("Failed to remove processed publish queue items")?; + Ok(()) + }; + + if let Err(e) = result { + error!(%e, "Error processing publish queue items"); + //TODO: What actions to take to recover? + return Err(anyhow::anyhow!("Error processing publish queue items")); + }; + + info!( + approx_wait_in_sec = next_epoch + .duration_since(tokio::time::Instant::now()) + .as_secs_f64(), + "Waiting for next epoch or shutdown signal" + ); + tokio::select! { + _ = shutdown_rx.recv() => { + info!("Shutting down publisher job"); + break; + } + // Sleep until next epoch + _ = tokio::time::sleep_until(next_epoch) => { + // Continue to process publish queue + next_epoch = tokio::time::Instant::now() + + std::time::Duration::from_millis(config.publisher.epoch_duration_ms as u64); + } + }; + } + + Ok(()) +} + +#[instrument(skip_all)] +async fn start_web( + publish_queue: PublishQueueType, + config: &ApplicationConfig, + mut shutdown_rx: Receiver<()>, +) -> Result<()> { + let app_state = routes::AppState { publish_queue }; + let app = Router::new() + .merge(routes::api_routes()) + .with_state(app_state); + + let listener = TcpListener::bind(&config.socket_address()) + .await + .context("Socket bind failed")?; + info!( + "Publisher web server listening on {}", + config.socket_address() + ); + axum::serve(listener, app.into_make_service()) + .with_graceful_shutdown(async move { + shutdown_rx.recv().await.ok(); + }) + .await + .context("Web server failed")?; + + Ok(()) +} diff --git a/akd/crates/publisher/src/routes/health.rs b/akd/crates/publisher/src/routes/health.rs index a70e96939b..a299442bc7 100644 --- a/akd/crates/publisher/src/routes/health.rs +++ b/akd/crates/publisher/src/routes/health.rs @@ -1,6 +1,6 @@ use axum::Json; use serde::{Deserialize, Serialize}; -use tracing::{error, info, instrument}; +use tracing::{info, instrument}; #[derive(Debug, Serialize, Deserialize)] pub struct ServerHealth { diff --git a/akd/crates/publisher/src/routes/mod.rs b/akd/crates/publisher/src/routes/mod.rs index 2440c8271f..b11dce1d23 100644 --- a/akd/crates/publisher/src/routes/mod.rs +++ b/akd/crates/publisher/src/routes/mod.rs @@ -1,14 +1,11 @@ -use akd_storage::{AkdDatabase, PublishQueueType}; +use akd_storage::PublishQueueType; use axum::routing::{get, post}; -use common::BitAkdDirectory; mod health; mod publish; #[derive(Clone)] pub(crate) struct AppState { - pub directory: BitAkdDirectory, - pub db: AkdDatabase, pub publish_queue: PublishQueueType, } diff --git a/akd/crates/publisher/src/routes/publish.rs 
diff --git a/akd/crates/publisher/src/routes/publish.rs b/akd/crates/publisher/src/routes/publish.rs
index 782aa3df8b..7ba7e80544 100644
--- a/akd/crates/publisher/src/routes/publish.rs
+++ b/akd/crates/publisher/src/routes/publish.rs
@@ -1,4 +1,5 @@
 use super::AppState;
+use akd::{AkdLabel, AkdValue};
 use akd_storage::PublishQueue;
 use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
 use serde::{Deserialize, Serialize};
@@ -22,8 +23,8 @@ pub async fn publish_handler(
 ) -> impl IntoResponse {
     info!("Handling publish request");
 
-    let akd_label: Vec<u8> = request.akd_label_b64.into_bytes();
-    let akd_value: Vec<u8> = request.akd_value_b64.into_bytes();
+    let akd_label: AkdLabel = AkdLabel(request.akd_label_b64.into_bytes());
+    let akd_value: AkdValue = AkdValue(request.akd_value_b64.into_bytes());
 
     //TODO: enqueue publish operation to to_publish queue
     if let Err(e) = publish_queue.enqueue(akd_label, akd_value).await {