1
0
mirror of https://github.com/bitwarden/server synced 2026-01-30 00:03:48 +00:00

Add publish queue and web handler

This commit is contained in:
Matt Gibson
2026-01-13 14:16:45 -08:00
parent a4fca3dfe6
commit 2ad61ff10a
16 changed files with 421 additions and 42 deletions

2
akd/Cargo.lock generated
View File

@@ -102,6 +102,7 @@ dependencies = [
"thiserror 2.0.17",
"tokio",
"tracing",
"uuid",
]
[[package]]
@@ -3127,6 +3128,7 @@ version = "1.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2"
dependencies = [
"getrandom 0.3.3",
"js-sys",
"serde",
"wasm-bindgen",

View File

@@ -21,6 +21,7 @@ rsa = { version = ">=0.9.2, <0.10" }
serde = { workspace = true }
thiserror.workspace = true
tracing.workspace = true
uuid = { workspace = true, features = ["v7", "serde"] }
[dev-dependencies]
tokio.workspace = true

View File

@@ -0,0 +1 @@
-- Down migration: removes the publish queue table (any queued, not-yet-published items are lost).
DROP TABLE IF EXISTS dbo.akd_publish_queue;

View File

@@ -0,0 +1,11 @@
-- Up migration: creates the publish queue table (idempotent).
-- id: the application supplies time-ordered v7 UUIDs on insert; NEWID() is
--     only a fallback default for rows inserted outside the application.
-- NOTE(review): raw_value is nullable here, but the Rust peek parser treats a
-- NULL raw_value as an error and enqueue always supplies one — confirm whether
-- NOT NULL is intended.
IF OBJECT_ID('dbo.akd_publish_queue', 'U') IS NULL
BEGIN
    CREATE TABLE dbo.akd_publish_queue (
        id UNIQUEIDENTIFIER NOT NULL DEFAULT NEWID() PRIMARY KEY,
        raw_label VARBINARY(256) NOT NULL,
        raw_value VARBINARY(2000) NULL
    );
    -- Single named unique constraint on raw_label. (The previous inline UNIQUE
    -- on the column would have created a second, redundant unique index next to
    -- this one, and the trailing comma after raw_value was a T-SQL syntax error.)
    CREATE UNIQUE INDEX IX_akd_publish_queue_raw_label
        ON dbo.akd_publish_queue(raw_label);
END

View File

@@ -2,8 +2,10 @@ mod akd_database;
pub mod akd_storage_config;
pub mod db_config;
pub mod ms_sql;
mod publish_queue;
pub mod vrf_key_config;
pub mod vrf_key_database;
pub use akd_database::*;
pub use publish_queue::*;
pub use vrf_key_database::*;

View File

@@ -5,6 +5,7 @@ pub const TABLE_AZKS: &str = "akd_azks";
pub const TABLE_HISTORY_TREE_NODES: &str = "akd_history_tree_nodes";
pub const TABLE_VALUES: &str = "akd_values";
pub const TABLE_VRF_KEYS: &str = "akd_vrf_keys";
pub const TABLE_PUBLISH_QUEUE: &str = "akd_publish_queue";
pub const TABLE_MIGRATIONS: &str = ms_database::TABLE_MIGRATIONS;
pub(crate) const MIGRATIONS: &[Migration] = load_migrations!("migrations/ms_sql");

View File

@@ -17,8 +17,8 @@ use ms_database::{IntoRow, MsSqlConnectionManager, Pool, PooledConnection};
use tracing::{debug, error, info, instrument, trace, warn};
use migrations::{
MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES,
TABLE_VRF_KEYS,
MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_PUBLISH_QUEUE,
TABLE_VALUES, TABLE_VRF_KEYS,
};
use tables::{
akd_storable_for_ms_sql::{AkdStorableForMsSql, Statement},
@@ -27,7 +27,14 @@ use tables::{
};
use crate::{
ms_sql::tables::vrf_key,
ms_sql::tables::{
akd_storable_for_ms_sql::QueryStatement,
publish_queue::{
bulk_delete_rows, bulk_delete_statement, enqueue_statement, peek_statement,
},
vrf_key,
},
publish_queue::{PublishQueue, PublishQueueError, PublishQueueItem},
vrf_key_config::VrfKeyConfig,
vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData},
};
@@ -124,6 +131,7 @@ impl MsSql {
DROP TABLE IF EXISTS {TABLE_HISTORY_TREE_NODES};
DROP TABLE IF EXISTS {TABLE_VALUES};
DROP TABLE IF EXISTS {TABLE_VRF_KEYS};
DROP TABLE IF EXISTS {TABLE_PUBLISH_QUEUE};
DROP TABLE IF EXISTS {TABLE_MIGRATIONS};"#
);
@@ -170,9 +178,7 @@ impl MsSql {
debug!("Statement executed successfully");
Ok(())
}
}
impl MsSql {
pub async fn get_existing_vrf_root_key_hash(
&self,
) -> Result<Option<Vec<u8>>, VrfKeyStorageError> {
@@ -209,7 +215,7 @@ impl MsSql {
}
#[instrument(skip(self, config), level = "debug")]
pub async fn get_vrf_key(
pub(crate) async fn get_vrf_key(
&self,
config: &VrfKeyConfig,
) -> Result<VrfKeyTableData, VrfKeyRetrievalError> {
@@ -253,7 +259,7 @@ impl MsSql {
}
#[instrument(skip(self, table_data), level = "debug")]
pub async fn store_vrf_key(
pub(crate) async fn store_vrf_key(
&self,
table_data: &VrfKeyTableData,
) -> Result<(), VrfKeyStorageError> {
@@ -744,3 +750,153 @@ impl Database for MsSql {
}
}
}
// MS SQL-backed implementation of the publish queue. All failures are logged
// at the failure site and collapsed into the opaque PublishQueueError.
#[async_trait]
impl PublishQueue for MsSql {
// Inserts one (label, value) pair via a parameterized INSERT statement.
#[instrument(skip(self, raw_label, raw_value), level = "debug")]
async fn enqueue(
&self,
raw_label: Vec<u8>,
raw_value: Vec<u8>,
) -> Result<(), PublishQueueError> {
debug!("Enqueuing item to publish queue");
let statement = enqueue_statement(raw_label, raw_value);
self.execute_statement(&statement)
.await
.map_err(|_| PublishQueueError)
}
// Reads up to `limit` items without removing them. Non-positive limits
// short-circuit to an empty Vec so no SQL with a bad TOP value is built
// (peek_statement splices the limit into the SQL text).
#[instrument(skip(self), level = "debug")]
async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
if limit <= 0 {
debug!("Peek called with non-positive limit, returning empty vector");
return Ok(vec![]);
}
debug!(limit, "Peeking items from publish queue");
let statement = peek_statement(limit);
let mut conn = self.get_connection().await.map_err(|_| {
error!("Failed to get DB connection for peek");
PublishQueueError
})?;
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|e| {
error!(error = %e, "Failed to execute peek query");
PublishQueueError
})?;
let mut queued_items = Vec::new();
{
// Only the first result set is expected; each row is decoded by the
// parser bundled with the QueryStatement.
let rows = query_stream.into_first_result().await.map_err(|e| {
error!(error = %e, "Failed to fetch rows for peek");
PublishQueueError
})?;
for row in rows {
let item = statement.parse(&row).map_err(|e| {
error!(error = %e, "Failed to parse publish queue item");
PublishQueueError
})?;
queued_items.push(item);
}
}
debug!(
item_count = queued_items.len(),
"Peeked items from publish queue"
);
Ok(queued_items)
}
// Deletes the given ids inside a manually managed transaction:
//   1. create a temp table for the ids,
//   2. bulk-insert the ids into it,
//   3. DELETE queue rows by joining against the temp table,
//   4. drop the temp table.
// Any failing step causes a ROLLBACK of the whole transaction.
#[instrument(skip(self), level = "debug")]
async fn remove(&self, ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError> {
if ids.is_empty() {
debug!("No IDs provided for removal, skipping operation");
return Ok(());
}
debug!(id_count = ids.len(), "Removing items from publish queue");
let temp_table = TempTable::PublishQueueIds;
let create_temp_table = temp_table.create();
let temp_table_name = &temp_table.to_string();
let mut conn = self.get_connection().await.map_err(|_| {
error!("Failed to get DB connection for remove");
PublishQueueError
})?;
debug!("Beginning transaction for remove");
// NOTE(review): the transaction is driven with raw BEGIN/COMMIT/ROLLBACK
// queries on a pooled connection; if the ROLLBACK itself fails, the
// connection may return to the pool with an open transaction — confirm the
// pool resets session state on check-in.
conn.simple_query("BEGIN TRANSACTION").await.map_err(|e| {
error!(error = %e, "Failed to begin transaction");
PublishQueueError
})?;
// All fallible work happens inside this future so the single match below
// can decide between COMMIT and ROLLBACK.
let result = async {
debug!("creating temp table for IDs");
conn.simple_query(&create_temp_table).await.map_err(|e| {
error!(error = %e, "Failed to create temp table");
PublishQueueError
})?;
let mut bulk = conn.bulk_insert(temp_table_name).await.map_err(|e| {
error!(error = %e, "Failed to start bulk insert");
PublishQueueError
})?;
for row in bulk_delete_rows(&ids)? {
bulk.send(row).await.map_err(|e| {
error!(error = %e, "Failed to add row to bulk insert");
PublishQueueError
})?;
}
bulk.finalize().await.map_err(|e| {
error!(error = %e, "Failed to finalize bulk insert");
PublishQueueError
})?;
debug!("Deleting rows from publish queue");
let delete_statement = bulk_delete_statement(temp_table_name);
conn.simple_query(&delete_statement.sql())
.await
.map_err(|e| {
error!(error = %e, "Failed to execute delete statement");
PublishQueueError
})?;
debug!("Dropping temp table");
let drop_temp_table = temp_table.drop();
conn.simple_query(&drop_temp_table).await.map_err(|e| {
error!(error = %e, "Failed to drop temp table");
PublishQueueError
})?;
Ok(())
};
match result.await {
Ok(_) => {
debug!("Committing transaction for delete");
conn.simple_query("COMMIT").await.map_err(|e| {
error!(error = %e, "Failed to commit transaction");
PublishQueueError
})?;
info!(
id_count = ids.len(),
"Successfully removed items from publish queue"
);
Ok(())
}
Err(e) => {
warn!(error = %e, "Remove failed, rolling back transaction");
conn.simple_query("ROLLBACK").await.map_err(|e| {
error!(error = %e, "Failed to roll back transaction");
PublishQueueError
})?;
error!(error = %e, "Remove rolled back");
Err(e)
}
}
}
}

View File

@@ -70,18 +70,14 @@ impl Statement {
}
}
pub(crate) struct QueryStatement<Out> {
pub(crate) struct QueryStatement<Out, TError> {
sql: String,
params: SqlParams,
parser: fn(&Row) -> Result<Out, StorageError>,
parser: fn(&Row) -> Result<Out, TError>,
}
impl<Out> QueryStatement<Out> {
pub fn new(
sql: String,
params: SqlParams,
parser: fn(&Row) -> Result<Out, StorageError>,
) -> Self {
impl<Out, TError> QueryStatement<Out, TError> {
pub fn new(sql: String, params: SqlParams, parser: fn(&Row) -> Result<Out, TError>) -> Self {
Self {
sql,
params,
@@ -97,7 +93,7 @@ impl<Out> QueryStatement<Out> {
self.params.values()
}
pub fn parse(&self, row: &Row) -> Result<Out, StorageError> {
pub fn parse(&self, row: &Row) -> Result<Out, TError> {
(self.parser)(row)
}
}

View File

@@ -1,4 +1,5 @@
pub(crate) mod akd_storable_for_ms_sql;
pub(crate) mod publish_queue;
pub(crate) mod temp_table;
pub(crate) mod values;
pub(crate) mod vrf_key;

View File

@@ -0,0 +1,116 @@
use ms_database::{IntoRow, TokenRow};
use tracing::{debug, error};
use uuid::Uuid;
use crate::{
ms_sql::{
migrations::TABLE_PUBLISH_QUEUE,
sql_params::SqlParams,
tables::{
akd_storable_for_ms_sql::{QueryStatement, Statement},
temp_table::TempTable,
},
},
publish_queue::{PublishQueueError, PublishQueueItem},
};
/// Builds the parameterized INSERT that appends one (label, value) pair to
/// the publish queue, generating a time-ordered v7 UUID for the row id.
pub fn enqueue_statement(raw_label: Vec<u8>, raw_value: Vec<u8>) -> Statement {
    debug!("Building enqueue_statement for publish queue");
    let mut params = SqlParams::new();
    params.add("id", Box::new(uuid::Uuid::now_v7()));
    params.add("raw_label", Box::new(raw_label));
    params.add("raw_value", Box::new(raw_value));
    // Resolve each placeholder key up front; every name was added just above,
    // so key_for cannot fail here.
    let id_key = params
        .key_for("id")
        .expect("id was added to the params list");
    let label_key = params
        .key_for("raw_label")
        .expect("raw_label was added to the params list");
    let value_key = params
        .key_for("raw_value")
        .expect("raw_value was added to the params list");
    let sql = format!(
        r#"
INSERT INTO {}
(id, raw_label, raw_value)
VALUES ({}, {}, {})"#,
        TABLE_PUBLISH_QUEUE, id_key, label_key, value_key,
    );
    Statement::new(sql, params)
}
/// Builds the SELECT that reads up to `limit` of the oldest items from the
/// publish queue, together with a row parser producing `PublishQueueItem`s.
///
/// NOTE(review): `limit` is spliced directly into the SQL text rather than
/// bound as a parameter; callers must pass a trusted, positive value —
/// `MsSql::peek` guards this today by returning early for `limit <= 0`.
pub fn peek_statement(limit: isize) -> QueryStatement<PublishQueueItem, PublishQueueError> {
debug!("Building peek_statement for publish queue");
// NOTE(review): ordering by id assumes UNIQUEIDENTIFIER sorts v7 UUIDs in
// insertion order; SQL Server compares GUID byte groups in its own order,
// not lexically — confirm this yields oldest-first.
let sql = format!(
r#"
SELECT TOP {} id, raw_label, raw_value
FROM {}
ORDER BY id ASC"#,
limit, TABLE_PUBLISH_QUEUE
);
// Every column is required by the parser; a NULL or missing column is
// logged and surfaced as the opaque PublishQueueError.
QueryStatement::new(sql, SqlParams::new(), |row: &ms_database::Row| {
let id: uuid::Uuid = row.get("id").ok_or_else(|| {
error!("id is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_label: &[u8] = row.get("raw_label").ok_or_else(|| {
error!("raw_label is NULL or missing in publish queue row");
PublishQueueError
})?;
let raw_value: &[u8] = row.get("raw_value").ok_or_else(|| {
error!("raw_value is NULL or missing in publish queue row");
PublishQueueError
})?;
Ok(PublishQueueItem {
id,
// Copy the borrowed column slices into owned buffers for the item.
raw_label: raw_label.to_vec(),
raw_value: raw_value.to_vec(),
})
})
}
/// Converts each id into a single-column `TokenRow` suitable for bulk-inserting
/// into the publish-queue-ids temp table.
///
/// Infallible today; the `Result` return type is kept so callers that `?` the
/// value keep working.
pub fn bulk_delete_rows(ids: &'_ [Uuid]) -> Result<Vec<TokenRow<'_>>, PublishQueueError> {
    debug!("Building bulk_delete_rows for publish queue");
    // `Uuid` is `Copy`, so dereference instead of the previous `id.clone()`,
    // and build the Vec with an iterator instead of a manual push loop.
    Ok(ids.iter().map(|id| (*id).into_row()).collect())
}
/// Builds the DELETE that removes every queue row whose id appears in the
/// (already populated) temp table.
///
/// A join against a bulk-inserted temp table is used instead of an
/// `IN (@id_0, @id_1, ...)` list so arbitrarily many ids can be deleted
/// without hitting SQL Server parameter limits. (The commented-out
/// parameterized `delete_statement` variant that used to sit below has been
/// removed as dead code.)
pub fn bulk_delete_statement(temp_table_name: &str) -> Statement {
    debug!("Building bulk_delete_statement deleting ids in temp table from the publish queue");
    let sql = format!(
        r#"
DELETE pq
FROM {} pq
INNER JOIN {} temp ON pq.id = temp.id
"#,
        TABLE_PUBLISH_QUEUE, temp_table_name
    );
    Statement::new(sql, SqlParams::new())
}

View File

@@ -7,6 +7,7 @@ pub(crate) enum TempTable {
HistoryTreeNodes,
Values,
RawLabelSearch,
PublishQueueIds,
}
impl std::fmt::Display for TempTable {
@@ -17,6 +18,7 @@ impl std::fmt::Display for TempTable {
TempTable::HistoryTreeNodes => write!(f, "{TEMP_HISTORY_TREE_NODES_TABLE}"),
TempTable::Values => write!(f, "{TEMP_VALUES_TABLE}"),
TempTable::RawLabelSearch => write!(f, "{TEMP_SEARCH_LABELS_TABLE}"),
TempTable::PublishQueueIds => write!(f, "{TEMP_PUBLISH_QUEUE_IDS_TABLE}"),
}
}
}
@@ -113,7 +115,7 @@ impl TempTable {
);
"#
),
}
},
TempTable::RawLabelSearch => format!(
r#"
CREATE TABLE {TEMP_SEARCH_LABELS_TABLE} (
@@ -121,7 +123,14 @@ impl TempTable {
PRIMARY KEY (raw_label)
);
"#
)
),
TempTable::PublishQueueIds => format!(
r#"
CREATE TABLE {TEMP_PUBLISH_QUEUE_IDS_TABLE} (
id UNIQUEIDENTIFIER NOT NULL PRIMARY KEY
);
"#
),
}
}
}
@@ -141,3 +150,4 @@ pub(crate) const TEMP_SEARCH_LABELS_TABLE: &str = "#akd_temp_search_labels";
pub(crate) const TEMP_AZKS_TABLE: &str = "#akd_temp_azks";
pub(crate) const TEMP_HISTORY_TREE_NODES_TABLE: &str = "#akd_temp_history_tree_nodes";
pub(crate) const TEMP_VALUES_TABLE: &str = "#akd_temp_values";
pub(crate) const TEMP_PUBLISH_QUEUE_IDS_TABLE: &str = "#akd_temp_publish_queue_ids";

View File

@@ -11,7 +11,7 @@ use crate::ms_sql::{
tables::akd_storable_for_ms_sql::QueryStatement,
};
pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState, StorageError> {
debug!("Building get_all query for label (label not logged for privacy)");
let mut params = SqlParams::new();
// the raw vector is the key for value storage
@@ -34,7 +34,7 @@ pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
pub fn get_by_flag(
raw_label: &AkdLabel,
flag: ValueStateRetrievalFlag,
) -> QueryStatement<ValueState> {
) -> QueryStatement<ValueState, StorageError> {
debug!(?flag, "Building get_by_flag query with flag");
let mut params = SqlParams::new();
params.add("raw_label", Box::new(raw_label.0.clone()));
@@ -102,7 +102,7 @@ pub fn get_by_flag(
pub fn get_versions_by_flag(
temp_table_name: &str,
flag: ValueStateRetrievalFlag,
) -> QueryStatement<LabelVersion> {
) -> QueryStatement<LabelVersion, StorageError> {
let mut params = SqlParams::new();
let (filter, epoch_col) = match flag {

View File

@@ -1,5 +1,5 @@
use akd::errors::StorageError;
use tracing::debug;
use thiserror::Error;
use tracing::{debug, error};
use crate::{
ms_sql::{
@@ -11,7 +11,11 @@ use crate::{
vrf_key_database::VrfKeyTableData,
};
pub fn get_first_root_key_hash() -> QueryStatement<Vec<u8>> {
#[derive(Debug, Error)]
#[error("VRF Key Storage Error: {0}")]
pub struct VrfKeyStorageError(String);
pub fn get_first_root_key_hash() -> QueryStatement<Vec<u8>, VrfKeyStorageError> {
debug!("Building has_vrf_key statement");
let sql = format!(
r#"
@@ -21,16 +25,17 @@ pub fn get_first_root_key_hash() -> QueryStatement<Vec<u8>> {
TABLE_VRF_KEYS
);
QueryStatement::new(sql, SqlParams::new(), |row: &ms_database::Row| {
let hash: &[u8] = row
.get("root_key_hash")
.ok_or_else(|| StorageError::Other("root_key_hash is NULL or missing".to_string()))?;
let hash: &[u8] = row.get("root_key_hash").ok_or_else(|| {
error!("root_key_hash is NULL or missing");
VrfKeyStorageError("root_key_hash is NULL or missing".to_string())
})?;
Ok(hash.to_vec())
})
}
pub fn get_statement(
config: &VrfKeyConfig,
) -> Result<QueryStatement<VrfKeyTableData>, VrfRootKeyError> {
) -> Result<QueryStatement<VrfKeyTableData, VrfKeyStorageError>, VrfRootKeyError> {
debug!("Building get_statement for vrf key");
let mut params = SqlParams::new();
params.add("root_key_type", Box::new(config.root_key_type() as i16));
@@ -55,19 +60,23 @@ pub fn get_statement(
Ok(QueryStatement::new(sql, params, from_row))
}
pub fn from_row(row: &ms_database::Row) -> Result<VrfKeyTableData, StorageError> {
let root_key_hash: &[u8] = row
.get("root_key_hash")
.ok_or_else(|| StorageError::Other("Missing root_key_hash column".to_string()))?;
let root_key_type: i16 = row
.get("root_key_type")
.ok_or_else(|| StorageError::Other("root_key_type is NULL or missing".to_string()))?;
pub fn from_row(row: &ms_database::Row) -> Result<VrfKeyTableData, VrfKeyStorageError> {
let root_key_hash: &[u8] = row.get("root_key_hash").ok_or_else(|| {
error!("root_key_hash is NULL or missing");
VrfKeyStorageError("root_key_hash is NULL or missing".to_string())
})?;
let root_key_type: i16 = row.get("root_key_type").ok_or_else(|| {
error!("root_key_type is NULL or missing");
VrfKeyStorageError("root_key_type is NULL or missing".to_string())
})?;
let enc_sym_key: Option<&[u8]> = row.get("enc_sym_key");
let sym_enc_vrf_key: &[u8] = row
.get("sym_enc_vrf_key")
.ok_or_else(|| StorageError::Other("sym_enc_vrf_key is NULL or missing".to_string()))?;
let sym_enc_vrf_key: &[u8] = row.get("sym_enc_vrf_key").ok_or_else(|| {
error!("sym_enc_vrf_key is NULL or missing");
VrfKeyStorageError("sym_enc_vrf_key is NULL or missing".to_string())
})?;
let sym_enc_vrf_key_nonce: &[u8] = row.get("sym_enc_vrf_key_nonce").ok_or_else(|| {
StorageError::Other("sym_enc_vrf_key_nonce is NULL or missing".to_string())
error!("sym_enc_vrf_key_nonce is NULL or missing");
VrfKeyStorageError("sym_enc_vrf_key_nonce is NULL or missing".to_string())
})?;
Ok(VrfKeyTableData {

View File

@@ -0,0 +1,64 @@
use async_trait::async_trait;
use thiserror::Error;
use crate::ms_sql::MsSql;
/// A single pending publish operation as stored in the publish queue table.
// Widened from `pub(crate)` to `pub`: the public `PublishQueue` trait returns
// this type from `peek`, so a crate-private item type would leak a private
// type through a public interface (rustc private-interfaces error).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PublishQueueItem {
    pub id: uuid::Uuid,
    pub raw_label: Vec<u8>,
    pub raw_value: Vec<u8>,
}

/// Opaque error for all publish-queue failures; details are logged at the
/// failure site rather than carried in the error value.
#[derive(Debug, Error)]
#[error("Publish queue error")]
pub struct PublishQueueError;

/// Storage-agnostic interface to the publish queue.
#[async_trait]
pub trait PublishQueue {
    /// Appends one (label, value) pair to the queue.
    async fn enqueue(
        &self,
        raw_label: Vec<u8>,
        raw_value: Vec<u8>,
    ) -> Result<(), PublishQueueError>;
    /// Returns up to `limit` items from the head of the queue without
    /// removing them; non-positive limits yield an empty vector.
    async fn peek(&self, limit: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError>;
    /// Deletes the items with the given ids; a no-op for an empty list.
    async fn remove(&self, ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError>;
}
/// Dispatch enum over the available publish-queue backends.
#[derive(Debug, Clone)]
pub enum PublishQueueType {
    MsSql(MsSql),
}

#[async_trait]
impl PublishQueue for PublishQueueType {
    // Each method delegates to the wrapped backend's own `PublishQueue`
    // implementation. The previous stub bodies returned Ok(()) / an empty Vec
    // without touching storage, which made the web handler report success
    // while silently dropping every publish request.
    async fn enqueue(
        &self,
        raw_label: Vec<u8>,
        raw_value: Vec<u8>,
    ) -> Result<(), PublishQueueError> {
        match self {
            PublishQueueType::MsSql(ms_sql) => ms_sql.enqueue(raw_label, raw_value).await,
        }
    }

    async fn peek(&self, max: isize) -> Result<Vec<PublishQueueItem>, PublishQueueError> {
        match self {
            PublishQueueType::MsSql(ms_sql) => ms_sql.peek(max).await,
        }
    }

    async fn remove(&self, ids: Vec<uuid::Uuid>) -> Result<(), PublishQueueError> {
        match self {
            PublishQueueType::MsSql(ms_sql) => ms_sql.remove(ids).await,
        }
    }
}

View File

@@ -1,4 +1,4 @@
use akd_storage::AkdDatabase;
use akd_storage::{AkdDatabase, PublishQueueType};
use axum::routing::{get, post};
use common::BitAkdDirectory;
@@ -9,6 +9,7 @@ mod publish;
/// Shared application state cloned into every axum handler.
pub struct AppState {
// The AKD directory used for proof/lookup operations.
pub directory: BitAkdDirectory,
// Direct database handle for storage operations.
pub db: AkdDatabase,
// Queue that publish requests are appended to for later processing.
pub publish_queue: PublishQueueType,
}
pub fn api_routes() -> axum::Router<AppState> {

View File

@@ -1,4 +1,5 @@
use super::AppState;
use akd_storage::PublishQueue;
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use tracing::{error, info, instrument};
@@ -16,7 +17,7 @@ pub struct PublishResponse {
#[instrument(skip_all)]
pub async fn publish_handler(
State(AppState { directory, .. }): State<AppState>,
State(AppState { publish_queue, .. }): State<AppState>,
Json(request): Json<PublishRequest>,
) -> impl IntoResponse {
info!("Handling publish request");
@@ -25,6 +26,13 @@ pub async fn publish_handler(
let akd_value: Vec<u8> = request.akd_value_b64.into_bytes();
//TODO: enqueue publish operation to to_publish queue
if let Err(e) = publish_queue.enqueue(akd_label, akd_value).await {
error!("Failed to enqueue publish request: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(PublishResponse { success: false }),
);
}
Json(PublishResponse { success: true })
(StatusCode::OK, Json(PublishResponse { success: true }))
}