mirror of https://github.com/bitwarden/server synced 2026-01-28 23:36:12 +00:00

First working build of publisher application

Author: Matt Gibson
Date:   2026-01-14 14:58:17 -08:00
parent 2c907f14ab
commit 8aa7141306
10 changed files with 241 additions and 114 deletions

View File

@@ -25,16 +25,17 @@ use tables::{
temp_table::TempTable,
values,
};
use uuid::Uuid;
use crate::{
ms_sql::tables::{
akd_storable_for_ms_sql::QueryStatement,
publish_queue::{
bulk_delete_rows, bulk_delete_statement, enqueue_statement, peek_statement,
bulk_delete_rows, bulk_delete_statement, enqueue_statement, peek_no_limit_statement,
peek_statement,
},
vrf_key,
},
publish_queue::{PublishQueue, PublishQueueError, PublishQueueItem},
publish_queue::{PublishQueue, PublishQueueError},
vrf_key_config::VrfKeyConfig,
vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData},
};
@@ -753,30 +754,36 @@ impl Database for MsSql {
#[async_trait]
impl PublishQueue for MsSql {
#[instrument(skip(self, raw_label, raw_value), level = "debug")]
async fn enqueue(
&self,
raw_label: Vec<u8>,
raw_value: Vec<u8>,
) -> Result<(), PublishQueueError> {
#[instrument(skip(self, label, value), level = "debug")]
async fn enqueue(&self, label: AkdLabel, value: AkdValue) -> Result<(), PublishQueueError> {
debug!("Enqueuing item to publish queue");
let statement = enqueue_statement(raw_label, raw_value);
let statement = enqueue_statement(label, value);
self.execute_statement(&statement)
.await
.map_err(|_| PublishQueueError)
}
#[instrument(skip(self), level = "debug")]
async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
if limit <= 0 {
debug!("Peek called with non-positive limit, returning empty vector");
return Ok(vec![]);
}
async fn peek(
&self,
limit: Option<isize>,
) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError> {
let statement = match limit {
Some(limit) if limit <= 0 => {
warn!("Peek called with non-positive limit, returning empty vector");
return Ok(vec![]);
}
Some(limit) => {
debug!(limit, "Peeking items from publish queue");
peek_statement(limit)
}
None => {
debug!("Peeking items from publish queue with no limit");
peek_no_limit_statement()
}
};
debug!(limit, "Peeking items from publish queue");
let statement = peek_statement(limit);
let mut conn = self.get_connection().await.map_err(|_| {
error!("Failed to get DB connection for peek");
PublishQueueError
@@ -804,6 +811,7 @@ impl PublishQueue for MsSql {
queued_items.push(item);
}
}
debug!(
item_count = queued_items.len(),
"Peeked items from publish queue"

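A minimal usage sketch of the reworked peek/remove contract on this backend, assuming an already-constructed queue handle; drain_queue is a hypothetical helper, and re-exporting PublishQueueError alongside the trait is an assumption, not something shown in this diff:

use akd_storage::{PublishQueue, PublishQueueError};
use uuid::Uuid;

// Drains up to `limit` queued items, or the whole queue when `limit` is None.
async fn drain_queue(
    queue: &impl PublishQueue,
    limit: Option<isize>,
) -> Result<usize, PublishQueueError> {
    // Some(n) with n <= 0 short-circuits to an empty Vec; None peeks without a limit.
    let rows = queue.peek(limit).await?;
    let ids: Vec<Uuid> = rows.iter().map(|(id, _)| *id).collect();
    let count = ids.len();
    if !ids.is_empty() {
        queue.remove(ids).await?;
    }
    Ok(count)
}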
View File

@@ -1,3 +1,4 @@
use akd::{AkdLabel, AkdValue};
use ms_database::{IntoRow, TokenRow};
use tracing::{debug, error};
use uuid::Uuid;
@@ -6,20 +7,17 @@ use crate::{
ms_sql::{
migrations::TABLE_PUBLISH_QUEUE,
sql_params::SqlParams,
tables::{
akd_storable_for_ms_sql::{QueryStatement, Statement},
temp_table::TempTable,
},
tables::akd_storable_for_ms_sql::{QueryStatement, Statement},
},
publish_queue::{PublishQueueError, PublishQueueItem},
publish_queue::PublishQueueError,
};
pub fn enqueue_statement(raw_label: Vec<u8>, raw_value: Vec<u8>) -> Statement {
pub fn enqueue_statement(label: AkdLabel, value: AkdValue) -> Statement {
debug!("Building enqueue_statement for publish queue");
let mut params = SqlParams::new();
params.add("id", Box::new(uuid::Uuid::now_v7()));
params.add("raw_label", Box::new(raw_label));
params.add("raw_value", Box::new(raw_value));
params.add("raw_label", Box::new(label.0));
params.add("raw_value", Box::new(value.0));
let sql = format!(
r#"
@@ -40,7 +38,9 @@ pub fn enqueue_statement(raw_label: Vec<u8>, raw_value: Vec<u8>) -> Statement {
Statement::new(sql, params)
}
pub fn peek_statement(limit: isize) -> QueryStatement<PublishQueueItem, PublishQueueError> {
pub fn peek_statement(
limit: isize,
) -> QueryStatement<(Uuid, (AkdLabel, AkdValue)), PublishQueueError> {
debug!("Building peek_statement for publish queue");
let sql = format!(
r#"
@@ -49,26 +49,42 @@ pub fn peek_statement(limit: isize) -> QueryStatement<PublishQueueItem, PublishQ
ORDER BY id ASC"#,
limit, TABLE_PUBLISH_QUEUE
);
QueryStatement::new(sql, SqlParams::new(), |row: &ms_database::Row| {
let id: uuid::Uuid = row.get("id").ok_or_else(|| {
error!("id is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_label: &[u8] = row.get("raw_label").ok_or_else(|| {
error!("raw_label is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_value: &[u8] = row.get("raw_value").ok_or_else(|| {
error!("raw_value is NULL or missing in publish queue row");
PublishQueueError
})?;
QueryStatement::new(sql, SqlParams::new(), publish_queue_item_from_row)
}
Ok(PublishQueueItem {
id,
raw_label: raw_label.to_vec(),
raw_value: raw_value.to_vec(),
})
})
pub fn peek_no_limit_statement() -> QueryStatement<(Uuid, (AkdLabel, AkdValue)), PublishQueueError>
{
debug!("Building peek_statement with no limit for publish queue");
let sql = format!(
r#"
SELECT id, raw_label, raw_value
FROM {}
ORDER BY id ASC"#,
TABLE_PUBLISH_QUEUE
);
QueryStatement::new(sql, SqlParams::new(), publish_queue_item_from_row)
}
fn publish_queue_item_from_row(
row: &ms_database::Row,
) -> Result<(Uuid, (AkdLabel, AkdValue)), PublishQueueError> {
let id: uuid::Uuid = row.get("id").ok_or_else(|| {
error!("id is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_label: &[u8] = row.get("raw_label").ok_or_else(|| {
error!("raw_label is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_value: &[u8] = row.get("raw_value").ok_or_else(|| {
error!("raw_value is NULL or missing in publish queue row");
PublishQueueError
})?;
Ok((
id,
(AkdLabel(raw_label.to_vec()), AkdValue(raw_value.to_vec())),
))
}
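Both peek statements now share publish_queue_item_from_row, so every row comes back in the shape the directory's publish call consumes later in this commit. A small sketch of that shape with illustrative placeholder bytes (not tied to the database row type):

use akd::{AkdLabel, AkdValue};
use uuid::Uuid;

// One mapped queue row: the row id kept for later deletion, plus the (label, value)
// pair wrapped in the akd newtypes over raw bytes.
fn example_item() -> (Uuid, (AkdLabel, AkdValue)) {
    (
        Uuid::now_v7(),
        (
            AkdLabel(b"example-label".to_vec()),
            AkdValue(b"example-value".to_vec()),
        ),
    )
}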
pub fn bulk_delete_rows(ids: &'_ [Uuid]) -> Result<Vec<TokenRow<'_>>, PublishQueueError> {

View File

@@ -1,31 +1,28 @@
use akd::{AkdLabel, AkdValue};
use async_trait::async_trait;
use thiserror::Error;
use uuid::Uuid;
use crate::{
db_config::DatabaseType,
ms_sql::MsSql,
publish_queue_config::{PublishQueueConfig, PublishQueueProvider},
AkdDatabase,
db_config::DatabaseType, ms_sql::MsSql, publish_queue_config::PublishQueueConfig, AkdDatabase,
};
pub(crate) struct PublishQueueItem {
pub id: uuid::Uuid,
pub raw_label: Vec<u8>,
pub raw_value: Vec<u8>,
}
#[derive(Debug, Error)]
#[error("Publish queue error")]
pub struct PublishQueueError;
#[async_trait]
pub trait PublishQueue {
// TODO: should this method ensure that a given label is not already present in the queue? How to handle that?
async fn enqueue(
&self,
raw_label: Vec<u8>,
raw_value: Vec<u8>,
raw_label: AkdLabel,
raw_value: AkdValue,
) -> Result<(), PublishQueueError>;
async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError>;
async fn peek(
&self,
limit: Option<isize>,
) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError>;
async fn remove(&self, ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError>;
}
@@ -36,8 +33,8 @@ pub enum PublishQueueType {
impl PublishQueueType {
pub fn new(config: &PublishQueueConfig, db: &AkdDatabase) -> PublishQueueType {
match &config.provider {
PublishQueueProvider::DbBacked => db.into(),
match config {
PublishQueueConfig::DbBacked => db.into(),
}
}
}
@@ -54,32 +51,26 @@ impl From<&AkdDatabase> for PublishQueueType {
impl PublishQueue for PublishQueueType {
async fn enqueue(
&self,
_raw_label: Vec<u8>,
_raw_value: Vec<u8>,
raw_label: AkdLabel,
raw_value: AkdValue,
) -> Result<(), PublishQueueError> {
match self {
PublishQueueType::MsSql(_ms_sql) => {
// Implement enqueue logic for MsSql
Ok(())
}
PublishQueueType::MsSql(ms_sql) => ms_sql.enqueue(raw_label, raw_value).await,
}
}
async fn peek(&self, _max: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
async fn peek(
&self,
limit: Option<isize>,
) -> Result<Vec<(Uuid, (AkdLabel, AkdValue))>, PublishQueueError> {
match self {
PublishQueueType::MsSql(_ms_sql) => {
// Implement peek logic for MsSql
Ok(vec![])
}
PublishQueueType::MsSql(ms_sql) => ms_sql.peek(limit).await,
}
}
async fn remove(&self, _ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError> {
async fn remove(&self, ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError> {
match self {
PublishQueueType::MsSql(_ms_sql) => {
// Implement remove logic for MsSql
Ok(())
}
PublishQueueType::MsSql(ms_sql) => ms_sql.remove(ids).await,
}
}
}
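PublishQueueType now forwards all three trait methods to the MsSql backend instead of returning stubbed values. A sketch of enqueuing through the backend-agnostic wrapper, assuming a queue already built via PublishQueueType::new or returned by initialize_directory; queue_one_update is a hypothetical helper:

use akd::{AkdLabel, AkdValue};
use akd_storage::{PublishQueue, PublishQueueType};

async fn queue_one_update(
    queue: &PublishQueueType,
    label_bytes: Vec<u8>,
    value_bytes: Vec<u8>,
) -> anyhow::Result<()> {
    // Callers now wrap raw bytes in the akd newtypes once, at the call site.
    queue
        .enqueue(AkdLabel(label_bytes), AkdValue(value_bytes))
        .await?;
    Ok(())
}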

View File

@@ -2,19 +2,8 @@ use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub struct PublishQueueConfig {
pub provider: PublishQueueProvider,
#[serde(default = "default_publish_limit")]
pub epoch_update_limit: Option<isize>,
}
fn default_publish_limit() -> Option<isize> {
None
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum PublishQueueProvider {
pub enum PublishQueueConfig {
/// Database-backed publish queue
DbBacked,
}
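With the provider wrapper struct gone, PublishQueueConfig is itself the internally tagged enum, and the epoch_update_limit knob moves to PublisherConfig later in this commit. A minimal deserialization sketch, assuming serde_json is available (for example as a dev-dependency):

use akd_storage::publish_queue_config::PublishQueueConfig;

fn parse_queue_config() -> serde_json::Result<PublishQueueConfig> {
    // With #[serde(tag = "type")], the unit variant is selected purely by the tag value.
    serde_json::from_str(r#"{ "type": "DbBacked" }"#)
}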

View File

@@ -169,10 +169,7 @@ async fn main() -> Result<()> {
vrf_key_config: akd_storage::vrf_key_config::VrfKeyConfig::B64EncodedSymmetricKey {
key: "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=".to_string(),
},
publish_queue_config: akd_storage::publish_queue_config::PublishQueueConfig {
provider: akd_storage::publish_queue_config::PublishQueueProvider::DbBacked,
epoch_update_limit: None,
},
publish_queue_config: akd_storage::publish_queue_config::PublishQueueConfig::DbBacked,
};
let (mut directory, db, _) = config
.initialize_directory::<TC>()

View File

@@ -9,20 +9,37 @@ const DEFAULT_EPOCH_DURATION_MS: u64 = 30000; // 30 seconds
pub struct ApplicationConfig {
pub storage: AkdStorageConfig,
pub publisher: PublisherConfig,
/// The unique Bitwarden installation ID using this AKD publisher instance.
/// This value is used to namespace AKD data to a given installation.
pub installation_id: Uuid,
/// The address the web server will bind to. Defaults to "127.0.0.1:3000".
#[serde(default = "default_web_server_bind_address")]
web_server_bind_address: String,
// web_server: WebServerConfig,
}
fn default_web_server_bind_address() -> String {
"127.0.0.1:3000".to_string()
}
#[derive(Clone, Debug, Deserialize)]
pub struct PublisherConfig {
/// The duration of each publishing epoch in milliseconds. Defaults to 30 seconds.
#[serde(default = "default_epoch_duration_ms")]
epoch_duration_ms: u64,
pub epoch_duration_ms: u64,
/// The limit to the number of AKD values to update in a single epoch. Defaults to no limit.
#[serde(default = "default_epoch_update_limit")]
pub epoch_update_limit: Option<isize>,
}
fn default_epoch_duration_ms() -> u64 {
DEFAULT_EPOCH_DURATION_MS
}
fn default_epoch_update_limit() -> Option<isize> {
None
}
impl ApplicationConfig {
/// Load configuration from multiple sources in order of priority:
/// 1. Environment variables (prefixed with AKD_PUBLISHER) - always applied with highest priority
@@ -65,6 +82,14 @@ impl ApplicationConfig {
self.publisher.validate()?;
Ok(())
}
/// Get the web server bind address as a SocketAddr
/// Panics if the address is invalid
pub fn socket_address(&self) -> std::net::SocketAddr {
self.web_server_bind_address
.parse()
.expect("Invalid web server bind address")
}
}
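The per-epoch update limit now lives on PublisherConfig with serde defaults for both publisher fields, and ApplicationConfig gains a defaulted bind address exposed through socket_address(). A small sketch of the defaulting behaviour from inside the publisher crate, assuming serde_json is available and the crate::config module path used here:

use crate::config::PublisherConfig;

fn default_publisher_config() -> serde_json::Result<()> {
    // An empty object picks up both serde defaults: 30s epochs and no per-epoch update cap.
    let cfg: PublisherConfig = serde_json::from_str("{}")?;
    assert_eq!(cfg.epoch_duration_ms, 30_000);
    assert_eq!(cfg.epoch_update_limit, None);
    Ok(())
}

Note that socket_address() panics (via expect) on an unparsable bind address rather than returning an error.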
impl PublisherConfig {

View File

@@ -1,7 +1,10 @@
use anyhow::Result;
use akd_storage::{PublishQueue, PublishQueueType};
use anyhow::{Context, Result};
use axum::Router;
use bitwarden_akd_configuration::BitwardenV1Configuration;
use tokio::sync::broadcast::Receiver;
use tracing::{info, instrument};
use common::BitAkdDirectory;
use tokio::{net::TcpListener, sync::broadcast::Receiver};
use tracing::{error, info, instrument, trace};
mod config;
mod routes;
@@ -15,28 +18,33 @@ pub struct AppHandles {
#[instrument(skip_all, name = "publisher_start")]
pub async fn start(config: ApplicationConfig, shutdown_rx: &Receiver<()>) -> Result<AppHandles> {
let (directory, db, publish_queue) = config
let (directory, _, publish_queue) = config
.storage
.initialize_directory::<BitwardenV1Configuration>()
.await?;
// Initialize write job
let write_handle = {
let mut shutdown_rx = shutdown_rx.resubscribe();
let shutdown_rx = shutdown_rx.resubscribe();
let publish_queue = publish_queue.clone();
let config = config.clone();
tokio::spawn(async move {
// wait until shutdown signal is received
shutdown_rx.recv().await.ok();
info!("Shutting down publisher write job");
if let Err(e) = start_publisher(directory, publish_queue, &config, shutdown_rx).await {
error!(err = %e, "Publisher write job failed");
}
})
};
// Initialize web server
let web_handle = {
let mut shutdown_rx = shutdown_rx.resubscribe();
let shutdown_rx = shutdown_rx.resubscribe();
let publish_queue = publish_queue.clone();
tokio::spawn(async move {
// wait forever until shutdown signal is received
shutdown_rx.recv().await.ok();
info!("Shutting down publisher web server");
if let Err(e) = start_web(publish_queue, &config, shutdown_rx).await {
error!(err = %e, "Web server failed");
}
})
};
@@ -45,3 +53,98 @@ pub async fn start(config: ApplicationConfig, shutdown_rx: &Receiver<()>) -> Res
web_handle,
})
}
#[instrument(skip_all)]
async fn start_publisher(
directory: BitAkdDirectory,
publish_queue: PublishQueueType,
config: &ApplicationConfig,
mut shutdown_rx: Receiver<()>,
) -> Result<()> {
let mut next_epoch = tokio::time::Instant::now()
+ std::time::Duration::from_millis(config.publisher.epoch_duration_ms as u64);
loop {
trace!("Processing publish queue for epoch");
// Pull items from publish queue
let (ids, items) = publish_queue
.peek(config.publisher.epoch_update_limit)
.await?
.into_iter()
.fold((vec![], vec![]), |mut acc, i| {
acc.0.push(i.0);
acc.1.push(i.1);
acc
});
let result: Result<()> = {
// Apply items to directory
directory
.publish(items)
.await
.context("AKD publish failed")?;
// Remove processed items from publish queue
publish_queue
.remove(ids)
.await
.context("Failed to remove processed publish queue items")?;
Ok(())
};
if let Err(e) = result {
error!(%e, "Error processing publish queue items");
//TODO: What actions to take to recover?
return Err(anyhow::anyhow!("Error processing publish queue items"));
};
info!(
approx_wait_in_sec = next_epoch
.duration_since(tokio::time::Instant::now())
.as_secs_f64(),
"Waiting for next epoch or shutdown signal"
);
tokio::select! {
_ = shutdown_rx.recv() => {
info!("Shutting down publisher job");
break;
}
// Sleep until next epoch
_ = tokio::time::sleep_until(next_epoch) => {
// Continue to process publish queue
next_epoch = tokio::time::Instant::now()
+ std::time::Duration::from_millis(config.publisher.epoch_duration_ms as u64);
}
};
}
Ok(())
}
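The fold above splits the peeked rows into parallel id and item vectors so the ids can be deleted after a successful publish. An equivalent form of that split (a sketch only, not what the commit uses) is Iterator::unzip:

use akd::{AkdLabel, AkdValue};
use uuid::Uuid;

fn split_rows(
    rows: Vec<(Uuid, (AkdLabel, AkdValue))>,
) -> (Vec<Uuid>, Vec<(AkdLabel, AkdValue)>) {
    // unzip performs the same id/item separation as the manual fold in start_publisher.
    rows.into_iter().unzip()
}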
#[instrument(skip_all)]
async fn start_web(
publish_queue: PublishQueueType,
config: &ApplicationConfig,
mut shutdown_rx: Receiver<()>,
) -> Result<()> {
let app_state = routes::AppState { publish_queue };
let app = Router::new()
.merge(routes::api_routes())
.with_state(app_state);
let listener = TcpListener::bind(&config.socket_address())
.await
.context("Socket bind failed")?;
info!(
"Publisher web server listening on {}",
config.socket_address()
);
axum::serve(listener, app.into_make_service())
.with_graceful_shutdown(async move {
shutdown_rx.recv().await.ok();
})
.await
.context("Web server failed")?;
Ok(())
}

View File

@@ -1,6 +1,6 @@
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::{error, info, instrument};
use tracing::{info, instrument};
#[derive(Debug, Serialize, Deserialize)]
pub struct ServerHealth {

View File

@@ -1,14 +1,11 @@
use akd_storage::{AkdDatabase, PublishQueueType};
use akd_storage::PublishQueueType;
use axum::routing::{get, post};
use common::BitAkdDirectory;
mod health;
mod publish;
#[derive(Clone)]
pub(crate) struct AppState {
pub directory: BitAkdDirectory,
pub db: AkdDatabase,
pub publish_queue: PublishQueueType,
}

View File

@@ -1,4 +1,5 @@
use super::AppState;
use akd::{AkdLabel, AkdValue};
use akd_storage::PublishQueue;
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
@@ -22,8 +23,8 @@ pub async fn publish_handler(
) -> impl IntoResponse {
info!("Handling publish request");
let akd_label: Vec<u8> = request.akd_label_b64.into_bytes();
let akd_value: Vec<u8> = request.akd_value_b64.into_bytes();
let akd_label: AkdLabel = AkdLabel(request.akd_label_b64.into_bytes());
let akd_value: AkdValue = AkdValue(request.akd_value_b64.into_bytes());
//TODO: enqueue publish operation to to_publish queue
if let Err(e) = publish_queue.enqueue(akd_label, akd_value).await {
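The handler now wraps the incoming strings in the akd newtypes before enqueuing. A sketch of that mapping with an illustrative request struct (the real struct name and any base64 handling are not shown in this diff; as written above, the raw UTF-8 bytes of the strings are what gets enqueued):

use akd::{AkdLabel, AkdValue};
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct PublishRequestExample {
    akd_label_b64: String,
    akd_value_b64: String,
}

fn to_queue_item(request: PublishRequestExample) -> (AkdLabel, AkdValue) {
    // Mirrors the handler: the strings' bytes go straight into the newtypes,
    // with no base64 decoding despite the _b64 field names.
    (
        AkdLabel(request.akd_label_b64.into_bytes()),
        AkdValue(request.akd_value_b64.into_bytes()),
    )
}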