From 929527a5c2a1522c8bc220e2ad5d80f4d10bd74e Mon Sep 17 00:00:00 2001 From: Matt Gibson Date: Thu, 18 Dec 2025 16:42:27 -0800 Subject: [PATCH] Clean up akd storage interface for application usage --- akd/Cargo.lock | 1 + akd/Cargo.toml | 1 + akd/crates/akd_storage/Cargo.toml | 1 + .../20251211_01_create_vrf_key_table/down.sql | 2 +- .../20251211_01_create_vrf_key_table/up.sql | 12 +- akd/crates/akd_storage/src/akd_database.rs | 101 ++++++++ .../akd_storage/src/akd_storage_config.rs | 58 ++++- akd/crates/akd_storage/src/db_config.rs | 9 +- akd/crates/akd_storage/src/lib.rs | 88 +------ .../akd_storage/src/ms_sql/migrations.rs | 1 + akd/crates/akd_storage/src/ms_sql/mod.rs | 72 +++++- .../akd_storage/src/ms_sql/tables/vrf_key.rs | 103 +++++++- akd/crates/akd_storage/src/vrf_key_config.rs | 64 +++++ .../akd_storage/src/vrf_key_database.rs | 234 ++++++++++++------ akd/crates/akd_test_utility/src/main.rs | 47 ++-- akd/crates/common/src/config/mod.rs | 43 ---- .../src/config/storage_manager_config.rs | 59 ----- akd/crates/common/src/lib.rs | 2 +- akd/crates/publisher/src/lib.rs | 1 - akd/crates/reader/src/lib.rs | 3 +- 20 files changed, 588 insertions(+), 314 deletions(-) create mode 100644 akd/crates/akd_storage/src/akd_database.rs delete mode 100644 akd/crates/common/src/config/mod.rs delete mode 100644 akd/crates/common/src/config/storage_manager_config.rs diff --git a/akd/Cargo.lock b/akd/Cargo.lock index 56992bb4a3..30f7956686 100644 --- a/akd/Cargo.lock +++ b/akd/Cargo.lock @@ -94,6 +94,7 @@ dependencies = [ "blake3", "chacha20poly1305", "ed25519-dalek", + "hex", "ms_database", "rand 0.8.5", "rsa", diff --git a/akd/Cargo.toml b/akd/Cargo.toml index fe2e9204c2..a5a15c6ad0 100644 --- a/akd/Cargo.toml +++ b/akd/Cargo.toml @@ -21,6 +21,7 @@ bitwarden-akd-configuration = { path = "crates/bitwarden-akd-configuration" } blake3 = "1.8.2" common = { path = "crates/common" } config = "0.15.18" +hex = "0.4.3" serde = { version = "1.0.228", features = ["derive"] } tokio = { version = "1.47.1", features = ["full"] } tracing = { version = "0.1.41" } diff --git a/akd/crates/akd_storage/Cargo.toml b/akd/crates/akd_storage/Cargo.toml index 7c5433c770..49ce4caef2 100644 --- a/akd/crates/akd_storage/Cargo.toml +++ b/akd/crates/akd_storage/Cargo.toml @@ -14,6 +14,7 @@ bitwarden-encoding = { path = "../bitwarden-encoding" } blake3.workspace = true chacha20poly1305 = { version = "0.10.1" } ed25519-dalek = { version = ">=2.1.1, <=2.2.0", features = ["rand_core"] } +hex.workspace = true ms_database = { path = "../ms_database" } rand = ">=0.8.5, <0.9" rsa = { version = ">=0.9.2, <0.10" } diff --git a/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/down.sql b/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/down.sql index 148c3115e4..fe87a4fa28 100644 --- a/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/down.sql +++ b/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/down.sql @@ -1 +1 @@ -DROP TABLE IF EXISTS dbo.vrf_key; +DROP TABLE IF EXISTS dbo.akd_vrf_keys; diff --git a/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/up.sql b/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/up.sql index 1f53748837..04deca8631 100644 --- a/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/up.sql +++ b/akd/crates/akd_storage/migrations/ms_sql/20251211_01_create_vrf_key_table/up.sql @@ -1,11 +1,11 @@ -IF OBJECT_ID('dbo.vrf_key', 'U') IS NULL +IF OBJECT_ID('dbo.akd_vrf_keys', 'U') IS NULL BEGIN - CREATE TABLE dbo.vrf_key ( + CREATE TABLE dbo.akd_vrf_keys ( root_key_hash VARBINARY(32) NOT NULL, - root_key_type INT NOT NULL, - enc_sym_key VARBINARY(32) NULL, - enc_sym_key_nonce VARBINARY(24) NULL, - sym_enc_vrf_key VARBINARY(32) NOT NULL, + root_key_type SMALLINT NOT NULL, + enc_sym_key VARBINARY(max) NULL, + sym_enc_vrf_key VARBINARY(48) NOT NULL, + sym_enc_vrf_key_nonce VARBINARY(24) NULL, PRIMARY KEY (root_key_hash, root_key_type) ); END diff --git a/akd/crates/akd_storage/src/akd_database.rs b/akd/crates/akd_storage/src/akd_database.rs new file mode 100644 index 0000000000..d3ef929fe7 --- /dev/null +++ b/akd/crates/akd_storage/src/akd_database.rs @@ -0,0 +1,101 @@ +use std::collections::HashMap; + +use akd::{ + errors::StorageError, + storage::{ + types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag}, + Database, DbSetState, Storable, + }, + AkdLabel, AkdValue, +}; +use async_trait::async_trait; + +use crate::{ + db_config::DatabaseType, vrf_key_config::VrfKeyConfig, vrf_key_database::VrfKeyDatabase, +}; + +#[derive(Debug, Clone)] +pub struct AkdDatabase { + db: DatabaseType, + vrf_key_config: VrfKeyConfig, +} + +impl AkdDatabase { + pub fn db(&self) -> &DatabaseType { + &self.db + } + + pub fn new(db: DatabaseType, vrf_key_config: VrfKeyConfig) -> AkdDatabase { + AkdDatabase { db, vrf_key_config } + } + + pub fn vrf_key_database(&self) -> VrfKeyDatabase { + VrfKeyDatabase::new(self.db.clone(), self.vrf_key_config.clone()) + } +} + +#[async_trait] +impl Database for AkdDatabase { + async fn set(&self, record: DbRecord) -> Result<(), StorageError> { + match &self.db { + DatabaseType::MsSql(db) => db.set(record).await, + } + } + + async fn batch_set( + &self, + records: Vec, + state: DbSetState, // TODO: unused in mysql example, but may be needed later + ) -> Result<(), StorageError> { + match &self.db { + DatabaseType::MsSql(db) => db.batch_set(records, state).await, + } + } + + async fn get(&self, id: &St::StorageKey) -> Result { + match &self.db { + DatabaseType::MsSql(db) => db.get::(id).await, + } + } + + async fn batch_get( + &self, + ids: &[St::StorageKey], + ) -> Result, StorageError> { + match &self.db { + DatabaseType::MsSql(db) => db.batch_get::(ids).await, + } + } + + // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's + // too restrictive for what we want, so generalize the name a bit. + async fn get_user_data(&self, raw_label: &AkdLabel) -> Result { + match &self.db { + DatabaseType::MsSql(db) => db.get_user_data(raw_label).await, + } + } + + // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's + // too restrictive for what we want, so generalize the name a bit. + async fn get_user_state( + &self, + raw_label: &AkdLabel, + flag: ValueStateRetrievalFlag, + ) -> Result { + match &self.db { + DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await, + } + } + + // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's + // too restrictive for what we want, so generalize the name a bit. + async fn get_user_state_versions( + &self, + raw_labels: &[AkdLabel], + flag: ValueStateRetrievalFlag, + ) -> Result, StorageError> { + match &self.db { + DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await, + } + } +} diff --git a/akd/crates/akd_storage/src/akd_storage_config.rs b/akd/crates/akd_storage/src/akd_storage_config.rs index 17a1eff1fc..44bae0c2a1 100644 --- a/akd/crates/akd_storage/src/akd_storage_config.rs +++ b/akd/crates/akd_storage/src/akd_storage_config.rs @@ -1,19 +1,65 @@ -use serde::Deserialize; +use std::time::Duration; -use crate::db_config::DbConfig; +use akd::storage::StorageManager; +use serde::Deserialize; +use thiserror::Error; +use tracing::error; + +use crate::{db_config::DbConfig, vrf_key_config::VrfKeyConfig, AkdDatabase}; #[derive(Debug, Clone, Deserialize)] pub struct AkdStorageConfig { - _db_config: DbConfig, + pub db_config: DbConfig, /// Controls how long items stay in cache before being evicted (in milliseconds). Defaults to 30 seconds. #[serde(default = "default_cache_item_lifetime_ms")] - _cache_item_lifetime_ms: usize, + pub cache_item_lifetime_ms: usize, /// Controls the maximum size of the cache in bytes. Defaults to no limit. #[serde(default)] - _cache_limit_bytes: Option, + pub cache_limit_bytes: Option, /// Controls how often the cache is cleaned (in milliseconds). Defaults to 15 seconds. #[serde(default = "default_cache_clean_ms")] - _cache_clean_ms: usize, + pub cache_clean_ms: usize, + pub vrf_key_config: VrfKeyConfig, +} + +#[derive(Debug, Error)] +#[error("Failed to initialize storage")] +pub struct AkdStorageInitializationError; + +impl AkdStorageConfig { + pub async fn initialize_storage( + &self, + ) -> Result<(StorageManager, AkdDatabase), AkdStorageInitializationError> { + let db = self.db_config.connect().await.map_err(|err| { + error!(%err, "Failed to connect to database"); + AkdStorageInitializationError + })?; + + let state = AkdDatabase::new(db, self.vrf_key_config.clone()); + + let cache_item_lifetime = Some(Duration::from_millis( + self.cache_item_lifetime_ms.try_into().map_err(|err| { + error!(%err, "Cache item lifetime out of range"); + AkdStorageInitializationError + })?, + )); + let cache_clean_frequency = Some(Duration::from_millis( + self.cache_clean_ms.try_into().map_err(|err| { + error!(%err, "Cache clean interval out of range"); + AkdStorageInitializationError + })?, + )); + + Ok(( + StorageManager::new( + state.clone(), + cache_item_lifetime, + self.cache_limit_bytes, + cache_clean_frequency, + ), + state, + )) + } } fn default_cache_item_lifetime_ms() -> usize { diff --git a/akd/crates/akd_storage/src/db_config.rs b/akd/crates/akd_storage/src/db_config.rs index 236d44830d..fd37853ba0 100644 --- a/akd/crates/akd_storage/src/db_config.rs +++ b/akd/crates/akd_storage/src/db_config.rs @@ -1,7 +1,7 @@ use akd::errors::StorageError; use serde::{Deserialize, Serialize}; -use crate::DatabaseType; +use crate::ms_sql::MsSql; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] @@ -12,6 +12,13 @@ pub enum DbConfig { }, } +/// Enum to represent different database types supported by the storage layer. +/// Each variant is cheap to clone for reuse across threads. +#[derive(Debug, Clone)] +pub enum DatabaseType { + MsSql(MsSql), +} + impl DbConfig { pub async fn connect(&self) -> Result { let db = match self { diff --git a/akd/crates/akd_storage/src/lib.rs b/akd/crates/akd_storage/src/lib.rs index ce438f6d64..7e864ea894 100644 --- a/akd/crates/akd_storage/src/lib.rs +++ b/akd/crates/akd_storage/src/lib.rs @@ -1,92 +1,8 @@ -use std::collections::HashMap; - -use akd::{ - errors::StorageError, - storage::{ - types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag}, - Database, DbSetState, Storable, - }, - AkdLabel, AkdValue, -}; -use async_trait::async_trait; - -use crate::ms_sql::MsSql; - +mod akd_database; pub mod akd_storage_config; pub mod db_config; pub mod ms_sql; pub mod vrf_key_config; pub mod vrf_key_database; -/// Enum to represent different database types supported by the storage layer. -/// Each variant is cheap to clone for reuse across threads. -#[derive(Debug, Clone)] -pub enum DatabaseType { - MsSql(MsSql), -} - -#[async_trait] -impl Database for DatabaseType { - async fn set(&self, record: DbRecord) -> Result<(), StorageError> { - match self { - DatabaseType::MsSql(db) => db.set(record).await, - } - } - - async fn batch_set( - &self, - records: Vec, - state: DbSetState, // TODO: unused in mysql example, but may be needed later - ) -> Result<(), StorageError> { - match self { - DatabaseType::MsSql(db) => db.batch_set(records, state).await, - } - } - - async fn get(&self, id: &St::StorageKey) -> Result { - match self { - DatabaseType::MsSql(db) => db.get::(id).await, - } - } - - async fn batch_get( - &self, - ids: &[St::StorageKey], - ) -> Result, StorageError> { - match self { - DatabaseType::MsSql(db) => db.batch_get::(ids).await, - } - } - - // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's - // too restrictive for what we want, so generalize the name a bit. - async fn get_user_data(&self, raw_label: &AkdLabel) -> Result { - match self { - DatabaseType::MsSql(db) => db.get_user_data(raw_label).await, - } - } - - // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's - // too restrictive for what we want, so generalize the name a bit. - async fn get_user_state( - &self, - raw_label: &AkdLabel, - flag: ValueStateRetrievalFlag, - ) -> Result { - match self { - DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await, - } - } - - // Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's - // too restrictive for what we want, so generalize the name a bit. - async fn get_user_state_versions( - &self, - raw_labels: &[AkdLabel], - flag: ValueStateRetrievalFlag, - ) -> Result, StorageError> { - match self { - DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await, - } - } -} +pub use akd_database::*; diff --git a/akd/crates/akd_storage/src/ms_sql/migrations.rs b/akd/crates/akd_storage/src/ms_sql/migrations.rs index 7d254b11b9..625abf9edf 100644 --- a/akd/crates/akd_storage/src/ms_sql/migrations.rs +++ b/akd/crates/akd_storage/src/ms_sql/migrations.rs @@ -4,6 +4,7 @@ use ms_database::{load_migrations, Migration}; pub const TABLE_AZKS: &str = "akd_azks"; pub const TABLE_HISTORY_TREE_NODES: &str = "akd_history_tree_nodes"; pub const TABLE_VALUES: &str = "akd_values"; +pub const TABLE_VRF_KEYS: &str = "akd_vrf_keys"; pub const TABLE_MIGRATIONS: &str = ms_database::TABLE_MIGRATIONS; pub(crate) const MIGRATIONS: &[Migration] = load_migrations!("migrations/ms_sql"); diff --git a/akd/crates/akd_storage/src/ms_sql/mod.rs b/akd/crates/akd_storage/src/ms_sql/mod.rs index 13acd7f9ff..c75a56aa96 100644 --- a/akd/crates/akd_storage/src/ms_sql/mod.rs +++ b/akd/crates/akd_storage/src/ms_sql/mod.rs @@ -18,6 +18,7 @@ use tracing::{debug, error, info, instrument, trace, warn}; use migrations::{ MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES, + TABLE_VRF_KEYS, }; use tables::{ akd_storable_for_ms_sql::{AkdStorableForMsSql, Statement}, @@ -25,6 +26,12 @@ use tables::{ values, }; +use crate::{ + ms_sql::tables::vrf_key, + vrf_key_config::VrfKeyConfig, + vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData}, +}; + const DEFAULT_POOL_SIZE: u32 = 100; pub struct MsSqlBuilder { @@ -116,6 +123,7 @@ impl MsSql { DROP TABLE IF EXISTS {TABLE_AZKS}; DROP TABLE IF EXISTS {TABLE_HISTORY_TREE_NODES}; DROP TABLE IF EXISTS {TABLE_VALUES}; + DROP TABLE IF EXISTS {TABLE_VRF_KEYS}; DROP TABLE IF EXISTS {TABLE_MIGRATIONS};"# ); @@ -164,6 +172,66 @@ impl MsSql { } } +impl MsSql { + #[instrument(skip(self, config), level = "debug")] + pub async fn get_vrf_key( + &self, + config: &VrfKeyConfig, + ) -> Result { + debug!("Retrieving VRF key from database"); + + let mut conn = self.get_connection().await.map_err(|err| { + error!(%err, "Failed to get DB connection for VRF key retrieval"); + VrfKeyRetrievalError::DatabaseError + })?; + + let statement = vrf_key::get_statement(&config).map_err(|err| { + error!(%err, "Failed to build VRF key retrieval statement"); + VrfKeyRetrievalError::CorruptedData + })?; + let query_stream = conn + .query(statement.sql(), &statement.params()) + .await + .map_err(|err| { + error!(%err, "Failed to execute VRF key retrieval query"); + VrfKeyRetrievalError::DatabaseError + })?; + + let row = query_stream.into_row().await.map_err(|err| { + error!(%err, "Failed to fetch VRF key row"); + VrfKeyRetrievalError::DatabaseError + })?; + + match row { + None => { + debug!("VRF key not found"); + return Err(VrfKeyRetrievalError::KeyNotFound); + } + Some(row) => { + debug!("VRF key found"); + vrf_key::from_row(&row).map_err(|err| { + error!(%err, "Failed to parse VRF key from database row"); + VrfKeyRetrievalError::CorruptedData + }) + } + } + } + + #[instrument(skip(self, table_data), level = "debug")] + pub async fn store_vrf_key( + &self, + table_data: &VrfKeyTableData, + ) -> Result<(), VrfKeyStorageError> { + debug!("Storing VRF key in database"); + + let statement = vrf_key::store_statement(table_data); + self.execute_statement(&statement).await.map_err(|err| { + error!(%err, "Failed to store VRF key in database"); + VrfKeyStorageError + }) + } +} + #[async_trait] impl Database for MsSql { #[instrument(skip(self, record), level = "debug")] @@ -552,7 +620,9 @@ impl Database for MsSql { statement.parse(&row) } else { debug!("Raw label not found"); - Err(StorageError::NotFound("ValueState for label not found".to_string())) + Err(StorageError::NotFound( + "ValueState for label not found".to_string(), + )) } } diff --git a/akd/crates/akd_storage/src/ms_sql/tables/vrf_key.rs b/akd/crates/akd_storage/src/ms_sql/tables/vrf_key.rs index 2211c0d925..cf6e8b4413 100644 --- a/akd/crates/akd_storage/src/ms_sql/tables/vrf_key.rs +++ b/akd/crates/akd_storage/src/ms_sql/tables/vrf_key.rs @@ -1,6 +1,101 @@ -use crate::vrf_key_database::VrfKeyTableData; +use akd::errors::StorageError; +use tracing::debug; -#[allow(unused)] -pub async fn get_vrf_key(root_key: &[u8]) -> VrfKeyTableData { - todo!() +use crate::{ + ms_sql::{ + migrations::TABLE_VRF_KEYS, + sql_params::SqlParams, + tables::akd_storable_for_ms_sql::{QueryStatement, Statement}, + }, + vrf_key_config::{VrfKeyConfig, VrfRootKeyError}, + vrf_key_database::VrfKeyTableData, +}; + +pub fn get_statement( + config: &VrfKeyConfig, +) -> Result, VrfRootKeyError> { + debug!("Building get_statement for vrf key"); + let mut params = SqlParams::new(); + params.add("root_key_type", Box::new(config.root_key_type() as i16)); + params.add( + "root_key_hash", + Box::new(config.root_key_hash().expect("valid root key hash")), + ); + + let sql = format!( + r#" + SELECT root_key_hash, root_key_type, enc_sym_key, sym_enc_vrf_key, sym_enc_vrf_key_nonce + FROM {} + WHERE root_key_type = {} AND root_key_hash = {}"#, + TABLE_VRF_KEYS, + params + .key_for("root_key_type") + .expect("root_key_type was added to the params list"), + params + .key_for("root_key_hash") + .expect("root_key_hash was added to the params list"), + ); + Ok(QueryStatement::new(sql, params, from_row)) +} + +pub fn from_row(row: &ms_database::Row) -> Result { + let root_key_hash: &[u8] = row + .get("root_key_hash") + .ok_or_else(|| StorageError::Other("Missing root_key_hash column".to_string()))?; + let root_key_type: i16 = row + .get("root_key_type") + .ok_or_else(|| StorageError::Other("root_key_type is NULL or missing".to_string()))?; + let enc_sym_key: Option<&[u8]> = row.get("enc_sym_key"); + let sym_enc_vrf_key: &[u8] = row + .get("sym_enc_vrf_key") + .ok_or_else(|| StorageError::Other("sym_enc_vrf_key is NULL or missing".to_string()))?; + let sym_enc_vrf_key_nonce: &[u8] = row.get("sym_enc_vrf_key_nonce").ok_or_else(|| { + StorageError::Other("sym_enc_vrf_key_nonce is NULL or missing".to_string()) + })?; + + Ok(VrfKeyTableData { + root_key_hash: root_key_hash.to_vec(), + root_key_type: root_key_type.into(), + enc_sym_key: enc_sym_key.map(|k| k.to_vec()), + sym_enc_vrf_key: sym_enc_vrf_key.to_vec(), + sym_enc_vrf_key_nonce: sym_enc_vrf_key_nonce.to_vec(), + }) +} + +pub fn store_statement(table_data: &VrfKeyTableData) -> Statement { + debug!("Building store_statement for vrf key"); + let mut params = SqlParams::new(); + params.add("root_key_hash", Box::new(table_data.root_key_hash.clone())); + params.add( + "root_key_type", + Box::new(Into::::into(table_data.root_key_type)), + ); + params.add("enc_sym_key", Box::new(table_data.enc_sym_key.clone())); + params.add( + "sym_enc_vrf_key", + Box::new(table_data.sym_enc_vrf_key.clone()), + ); + params.add( + "sym_enc_vrf_key_nonce", + Box::new(table_data.sym_enc_vrf_key_nonce.clone()), + ); + let sql = format!( + r#" + MERGE INTO dbo.{TABLE_VRF_KEYS} AS t + USING (SELECT {}) AS source + ON t.root_key_hash = source.root_key_hash AND t.root_key_type = source.root_key_type + WHEN MATCHED THEN + UPDATE SET {} + WHEN NOT MATCHED THEN + INSERT ({}) + VALUES ({});"#, + params.keys_as_columns().join(", "), + params + .set_columns_equal_except("t.", "source.", vec!["root_key_hash", "root_key_type"]) + .join(", "), + params.columns().join(", "), + params.keys().join(", "), + ); + + Statement::new(sql, params) } diff --git a/akd/crates/akd_storage/src/vrf_key_config.rs b/akd/crates/akd_storage/src/vrf_key_config.rs index 9da2971660..e7a34dd0bd 100644 --- a/akd/crates/akd_storage/src/vrf_key_config.rs +++ b/akd/crates/akd_storage/src/vrf_key_config.rs @@ -1,4 +1,11 @@ +use std::str::FromStr; + +use rsa::pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey}; use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tracing::error; + +use crate::vrf_key_database::VrfRootKeyType; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] @@ -36,3 +43,60 @@ pub enum VrfKeyConfig { /// Losing this key is equivalent to losing your directory's VRF key. PEMEncodedRSAKey { private_key: String }, } + +#[derive(Debug, Error)] +#[error("Error reading root key from configuration")] +pub struct VrfRootKeyError; + +impl VrfKeyConfig { + pub fn root_key_bytes(&self) -> Result, VrfRootKeyError> { + match self { + #[cfg(test)] + VrfKeyConfig::ConstantVrfKey => { + // This is the hard coded vrf key + Ok( + hex::decode("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721") + .map_err(|err| { + error!("Failed to decode hardcoded vrf key: {}", err); + VrfRootKeyError + })?, + ) + } + VrfKeyConfig::B64EncodedSymmetricKey { key } => bitwarden_encoding::B64::from_str(&key) + .map_err(|err| { + error!(%err, "Failed to decode symmetric key from base64 format"); + VrfRootKeyError + }) + .map(|b64| Vec::::from(b64)), + VrfKeyConfig::PEMEncodedRSAKey { private_key } => { + let rsa_private_key = + rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| { + error!(%err, "Failed to decode RSA private key from PEM format"); + VrfRootKeyError + })?; + Ok(rsa_private_key + .to_pkcs1_der() + .map_err(|err| { + error!(%err, "Failed to encode RSA private key to DER format"); + VrfRootKeyError + })? + .as_bytes() + .to_vec()) + } + } + } + + pub fn root_key_hash(&self) -> Result, VrfRootKeyError> { + let root_key_bytes = self.root_key_bytes()?; + Ok(blake3::hash(&root_key_bytes).as_bytes().to_vec()) + } + + pub fn root_key_type(&self) -> VrfRootKeyType { + match self { + #[cfg(test)] + VrfKeyConfig::ConstantVrfKey => VrfRootKeyType::None, + VrfKeyConfig::B64EncodedSymmetricKey { .. } => VrfRootKeyType::SymmetricKey, + VrfKeyConfig::PEMEncodedRSAKey { .. } => VrfRootKeyType::RsaKey, + } + } +} diff --git a/akd/crates/akd_storage/src/vrf_key_database.rs b/akd/crates/akd_storage/src/vrf_key_database.rs index 8fd755df51..6bf0dded27 100644 --- a/akd/crates/akd_storage/src/vrf_key_database.rs +++ b/akd/crates/akd_storage/src/vrf_key_database.rs @@ -1,17 +1,14 @@ -use std::str::FromStr; - +use akd::ecvrf::{VRFKeyStorage, VrfError}; +use async_trait::async_trait; use chacha20poly1305::{ aead::{generic_array::GenericArray, Aead}, AeadCore, KeyInit, XChaCha20Poly1305, }; -use rsa::{ - pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey}, - Pkcs1v15Encrypt, -}; +use rsa::{pkcs1::DecodeRsaPrivateKey, Pkcs1v15Encrypt}; use thiserror::Error; use tracing::{error, info}; -use crate::vrf_key_config::VrfKeyConfig; +use crate::{db_config::DatabaseType, vrf_key_config::VrfKeyConfig}; /// Represents a storage-layer error #[derive(Debug, Error)] @@ -35,31 +32,116 @@ pub struct VrfKeyCreationError; #[error("Internal VRF key storage error")] pub struct VrfKeyStorageError; -#[allow(unused)] -trait VrfKeyTable { - async fn get_vrf_key(config: VrfKeyConfig) -> Result; - async fn store_vrf_key(table_data: VrfKeyTableData) -> Result<(), VrfKeyStorageError>; +#[derive(Debug, Clone)] +pub struct VrfKeyDatabase { + db: DatabaseType, + vrf_key_config: VrfKeyConfig, + cached_vrf_key: Option>, +} + +impl VrfKeyDatabase { + pub fn new(db: DatabaseType, config: VrfKeyConfig) -> VrfKeyDatabase { + VrfKeyDatabase { + db, + vrf_key_config: config, + cached_vrf_key: None, + } + } + async fn get_vrf_key(&self) -> Result { + match &self.db { + DatabaseType::MsSql(db) => db.get_vrf_key(&self.vrf_key_config).await, + } + } + async fn store_vrf_key(&self, table_data: &VrfKeyTableData) -> Result<(), VrfKeyStorageError> { + match &self.db { + DatabaseType::MsSql(db) => db.store_vrf_key(table_data).await, + } + } +} + +#[async_trait] +impl VRFKeyStorage for VrfKeyDatabase { + async fn retrieve(&self) -> Result, VrfError> { + if let Some(cached_key) = &self.cached_vrf_key { + return Ok(cached_key.clone()); + } + + match &self.get_vrf_key().await { + Ok(table_data) => table_data + .to_vrf_key(&self.vrf_key_config) + .await + .map(|k| k.0) + .map_err(|err| { + VrfError::SigningKey(format!("Error decrypting signing key: {err}")) + }), + Err(VrfKeyRetrievalError::KeyNotFound) => { + // Make a new key + let (table_data, key) = + VrfKeyTableData::new(&self.vrf_key_config) + .await + .map_err(|err| { + VrfError::SigningKey(format!("Failed to create a new VRF key: {err}")) + })?; + // store store + self.store_vrf_key(&table_data).await.map_err(|err| { + VrfError::SigningKey(format!("Failed to store a new VRF key: {err}")) + })?; + + // and return + Ok(key.0) + } + Err(err) => { + error!(%err, "Key retrieval error"); + Err(VrfError::SigningKey("Key retrieval error".to_string())) + } + } + } } pub struct VrfKeyTableData { pub root_key_hash: Vec, - pub root_key_type: RootKeyType, + pub root_key_type: VrfRootKeyType, pub enc_sym_key: Option>, pub sym_enc_vrf_key: Vec, pub sym_enc_vrf_key_nonce: Vec, } -pub enum RootKeyType { +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum VrfRootKeyType { #[cfg(test)] None = 0, SymmetricKey = 1, RsaKey = 2, } +impl From for VrfRootKeyType { + fn from(value: i16) -> Self { + match value { + 1 => VrfRootKeyType::SymmetricKey, + 2 => VrfRootKeyType::RsaKey, + #[cfg(test)] + 0 => VrfRootKeyType::None, + _ => panic!("Invalid VrfRootKeyType value: {}", value), + } + } +} + +impl From for i16 { + fn from(value: VrfRootKeyType) -> Self { + match value { + VrfRootKeyType::SymmetricKey => 1, + VrfRootKeyType::RsaKey => 2, + #[cfg(test)] + VrfRootKeyType::None => 0, + } + } +} + +#[derive(Clone)] pub struct VrfKey(pub Vec); impl VrfKeyTableData { - pub async fn new(config: VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> { + pub async fn new(config: &VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> { info!("Generating new VRF key and table data"); // handle constant key case separately to avoid unnecessary key generation / parsing #[cfg(test)] @@ -68,8 +150,8 @@ impl VrfKeyTableData { return Ok(( VrfKeyTableData { - root_key_hash: vec![], - root_key_type: RootKeyType::None, + root_key_hash: config.root_key_hash().expect("hard coded vrf key"), + root_key_type: VrfRootKeyType::None, enc_sym_key: None, sym_enc_vrf_key: vec![], sym_enc_vrf_key_nonce: vec![], @@ -78,17 +160,14 @@ impl VrfKeyTableData { )); } - let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key } = &config { - let raw_key = bitwarden_encoding::B64::from_str(key).map_err(|err| { - error!(%err, "Invalid b64 encoding of symmetric key"); - VrfKeyCreationError - })?; + let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key: _ } = &config { + let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?; ( - XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| { + XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| { error!(%err, "Invalid symmetric key length"); VrfKeyCreationError })?, - raw_key.as_bytes().to_vec(), + raw_key, ) } else { let key = XChaCha20Poly1305::generate_key(rand::thread_rng()); @@ -103,16 +182,22 @@ impl VrfKeyTableData { VrfKeyCreationError })?; - match config { + match &config { #[cfg(test)] VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => { - let root_key_hash = blake3::hash(&sym_key).as_bytes().to_vec(); + let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?; + error!( + rkh = root_key_hash.len(), + sevk = sym_enc_vrf_key.len(), + sevkn = nonce.len(), + "lengths of stuff!!!!!\n\n\n\n" + ); Ok(( VrfKeyTableData { root_key_hash, - root_key_type: RootKeyType::SymmetricKey, + root_key_type: VrfRootKeyType::SymmetricKey, enc_sym_key: None, sym_enc_vrf_key, sym_enc_vrf_key_nonce: nonce.to_vec(), @@ -122,21 +207,11 @@ impl VrfKeyTableData { } VrfKeyConfig::PEMEncodedRSAKey { private_key } => { let rsa_private_key = - rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| { + rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| { error!(%err, "Failed to decode RSA private key from PEM format"); VrfKeyCreationError })?; - let root_key_hash = blake3::hash( - rsa_private_key - .to_pkcs1_der() - .map_err(|err| { - error!(%err, "Failed to encode RSA private key to DER format"); - VrfKeyCreationError - })? - .as_bytes(), - ) - .as_bytes() - .to_vec(); + let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?; let rsa_public_key = rsa_private_key.to_public_key(); let enc_sym_key = rsa_public_key .encrypt(&mut rand::thread_rng(), Pkcs1v15Encrypt, &sym_key) @@ -148,7 +223,7 @@ impl VrfKeyTableData { Ok(( VrfKeyTableData { root_key_hash, - root_key_type: RootKeyType::RsaKey, + root_key_type: VrfRootKeyType::RsaKey, enc_sym_key: Some(enc_sym_key), sym_enc_vrf_key, sym_enc_vrf_key_nonce: nonce.to_vec(), @@ -159,7 +234,7 @@ impl VrfKeyTableData { } } - pub async fn to_vrf_key(&self, config: VrfKeyConfig) -> Result { + pub async fn to_vrf_key(&self, config: &VrfKeyConfig) -> Result { info!("Decrypting VrfKeyTableData to obtain VRF key"); // handle constant key case separately to avoid unnecessary key generation / parsing #[cfg(test)] @@ -179,15 +254,12 @@ impl VrfKeyTableData { return Err(VrfKeyCreationError); } let nonce = GenericArray::from_slice(self.sym_enc_vrf_key_nonce.as_ref()); - let vrf_key = match config { + let vrf_key = match &config { #[cfg(test)] VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above - VrfKeyConfig::B64EncodedSymmetricKey { key } => { - let raw_key = bitwarden_encoding::B64::from_str(&key).map_err(|err| { - error!(%err, "Invalid b64 encoding of symmetric key"); - VrfKeyCreationError - })?; - let sym = XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| { + VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => { + let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?; + let sym = XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| { error!(%err, "Invalid symmetric key length"); VrfKeyCreationError })?; @@ -202,7 +274,7 @@ impl VrfKeyTableData { } VrfKeyConfig::PEMEncodedRSAKey { private_key } => { let rsa_private_key = - rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| { + rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| { error!(%err, "Failed to decode RSA private key from PEM format"); VrfKeyCreationError })?; @@ -303,8 +375,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_generation_from_symmetric_key() { let config = create_test_symmetric_config(); - let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap(); - let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap(); + let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap(); + let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap(); assert_eq!(table_data.enc_sym_key, None); assert_eq!( @@ -322,8 +394,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= pub async fn test_generation_from_rsa_key() { let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap(); let config = create_test_rsa_config(); - let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap(); - let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap(); + let (table_data, vrf_key) = super::VrfKeyTableData::new(&config.clone()).await.unwrap(); + let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap(); assert_eq!( table_data.root_key_hash, [ @@ -343,8 +415,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_generation_from_constant_key() { let config = super::VrfKeyConfig::ConstantVrfKey; - let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap(); - let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap(); + let (table_data, vrf_key) = super::VrfKeyTableData::new(&config.clone()).await.unwrap(); + let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap(); assert_eq!(table_data.root_key_hash, vec![]); assert_eq!(table_data.enc_sym_key, None); @@ -360,19 +432,19 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= key: "not!valid@base64#".to_string(), }; - let result = super::VrfKeyTableData::new(config.clone()).await; + let result = super::VrfKeyTableData::new(&config.clone()).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } #[tokio::test] pub async fn test_invalid_base64_during_retrieval() { let config_valid = create_test_symmetric_config(); - let (table_data, _) = super::VrfKeyTableData::new(config_valid).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config_valid).await.unwrap(); let config_invalid = super::VrfKeyConfig::B64EncodedSymmetricKey { key: "not!valid@base64#".to_string(), }; - let result = table_data.to_vrf_key(config_invalid).await; + let result = table_data.to_vrf_key(&config_invalid).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } @@ -381,7 +453,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= let short_key = bitwarden_encoding::B64::from(vec![0u8; 16]).to_string(); let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: short_key }; - let result = super::VrfKeyTableData::new(config).await; + let result = super::VrfKeyTableData::new(&config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } @@ -393,7 +465,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= private_key: malformed_pem.to_string(), }; - let result = super::VrfKeyTableData::new(config).await; + let result = super::VrfKeyTableData::new(&config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } @@ -406,62 +478,62 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= private_key: missing_headers.to_string(), }; - let result = super::VrfKeyTableData::new(config).await; + let result = super::VrfKeyTableData::new(&config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } #[tokio::test] pub async fn test_wrong_symmetric_key_decryption() { let config1 = create_test_symmetric_config(); - let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap(); let config2 = super::VrfKeyConfig::B64EncodedSymmetricKey { key: generate_random_symmetric_key_b64(), }; - let result = table_data.to_vrf_key(config2).await; + let result = table_data.to_vrf_key(&config2).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } #[tokio::test] pub async fn test_wrong_rsa_key_decryption() { let config1 = create_test_rsa_config(); - let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap(); let config2 = super::VrfKeyConfig::PEMEncodedRSAKey { private_key: generate_random_rsa_key_pem(), }; - let result = table_data.to_vrf_key(config2).await; + let result = table_data.to_vrf_key(&config2).await; assert!(result.is_err()); } #[tokio::test] pub async fn test_rsa_missing_enc_sym_key() { let config = create_test_rsa_config(); - let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap(); + let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); table_data.enc_sym_key = None; - let result = table_data.to_vrf_key(config).await; + let result = table_data.to_vrf_key(&config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } #[tokio::test] pub async fn test_wrong_nonce() { let config = create_test_symmetric_config(); - let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap(); + let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); table_data.sym_enc_vrf_key_nonce.truncate(10); - let result = table_data.to_vrf_key(config).await; + let result = table_data.to_vrf_key(&config).await; assert!(result.is_err()); } #[tokio::test] pub async fn test_nonce_size_validation() { let config = create_test_symmetric_config(); - let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); assert_eq!(table_data.sym_enc_vrf_key_nonce.len(), 24); } @@ -470,7 +542,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= pub async fn test_empty_symmetric_key() { let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: String::new() }; - let result = super::VrfKeyTableData::new(config).await; + let result = super::VrfKeyTableData::new(&config).await; assert!(result.is_err()); } @@ -480,7 +552,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= private_key: String::new(), }; - let result = super::VrfKeyTableData::new(config).await; + let result = super::VrfKeyTableData::new(&config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } @@ -489,8 +561,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= let config1 = create_test_symmetric_config(); let config2 = create_test_symmetric_config(); - let (table_data1, vrf_key1) = super::VrfKeyTableData::new(config1).await.unwrap(); - let (table_data2, vrf_key2) = super::VrfKeyTableData::new(config2).await.unwrap(); + let (table_data1, vrf_key1) = super::VrfKeyTableData::new(&config1).await.unwrap(); + let (table_data2, vrf_key2) = super::VrfKeyTableData::new(&config2).await.unwrap(); assert_ne!(vrf_key1.0, vrf_key2.0); assert_ne!( @@ -502,7 +574,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_symmetric_key_not_persisted() { let config = create_test_symmetric_config(); - let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); let symmetric_key_bytes = bitwarden_encoding::B64::from_str(TEST_SYMMETRIC_KEY_B64).unwrap(); @@ -516,7 +588,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_rsa_private_key_not_persisted() { let config = create_test_rsa_config(); - let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap(); let rsa_der = rsa_key.to_pkcs1_der().unwrap(); @@ -532,7 +604,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_vrf_key_is_encrypted_at_rest() { let config = create_test_symmetric_config(); - let (table_data, vrf_key) = super::VrfKeyTableData::new(config).await.unwrap(); + let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap(); assert!(!table_data .sym_enc_vrf_key @@ -543,7 +615,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_symmetric_key_encryption_in_rsa_mode() { let config = create_test_rsa_config(); - let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap(); let enc_sym_key = table_data.enc_sym_key.as_ref().unwrap(); let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap(); @@ -556,22 +628,22 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII= #[tokio::test] pub async fn test_cannot_decrypt_symmetric_with_rsa_config() { let sym_config = create_test_symmetric_config(); - let (table_data, _) = super::VrfKeyTableData::new(sym_config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&sym_config).await.unwrap(); let rsa_config = create_test_rsa_config(); - let result = table_data.to_vrf_key(rsa_config).await; + let result = table_data.to_vrf_key(&rsa_config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } #[tokio::test] pub async fn test_cannot_decrypt_rsa_with_symmetric_config() { let rsa_config = create_test_rsa_config(); - let (table_data, _) = super::VrfKeyTableData::new(rsa_config).await.unwrap(); + let (table_data, _) = super::VrfKeyTableData::new(&rsa_config).await.unwrap(); let sym_config = create_test_symmetric_config(); - let result = table_data.to_vrf_key(sym_config).await; + let result = table_data.to_vrf_key(&sym_config).await; assert!(matches!(result, Err(super::VrfKeyCreationError))); } } diff --git a/akd/crates/akd_test_utility/src/main.rs b/akd/crates/akd_test_utility/src/main.rs index fa56d33899..bdceecdfcb 100644 --- a/akd/crates/akd_test_utility/src/main.rs +++ b/akd/crates/akd_test_utility/src/main.rs @@ -1,8 +1,7 @@ -use akd::ecvrf::HardCodedAkdVRF; -use akd::storage::StorageManager; use akd::Directory; -use akd_storage::db_config::DbConfig; -use akd_storage::DatabaseType; +use akd_storage::akd_storage_config::AkdStorageConfig; +use akd_storage::db_config::{DatabaseType, DbConfig}; +use akd_storage::AkdDatabase; use anyhow::{Context, Result}; use clap::{Parser, ValueEnum}; use commands::Command; @@ -159,33 +158,37 @@ async fn main() -> Result<()> { // Create database connection info!("Connecting to MS SQL database"); - let config = DbConfig::MsSql { - connection_string, - pool_size: args.pool_size, + let config = AkdStorageConfig { + db_config: DbConfig::MsSql { + connection_string: connection_string.clone(), + pool_size: args.pool_size, + }, + cache_item_lifetime_ms: 30000, + cache_limit_bytes: None, + cache_clean_ms: 15000, + vrf_key_config: akd_storage::vrf_key_config::VrfKeyConfig::B64EncodedSymmetricKey { + key: "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=".to_string(), + }, }; - let db = config - .connect() + let (storage_manager, state) = config + .initialize_storage() .await - .context("Failed to connect to database")?; + .context("Failed to initialize storage")?; // Handle pre-processing modes - if let Some(()) = pre_process_mode(&args, &db).await? { + if let Some(()) = pre_process_mode(&args, &state.db()).await? { return Ok(()); } - let storage_manager = StorageManager::new(db.clone(), None, None, None); - let vrf = HardCodedAkdVRF {}; - let mut directory = Directory::::new(storage_manager.clone(), vrf) + let mut directory = Directory::::new(storage_manager, state.vrf_key_database()) .await .context("Failed to create AKD directory")?; let (tx, mut rx) = channel(2); - tokio::spawn(async move { - directory_host::init_host::(&mut rx, &mut directory).await - }); + tokio::spawn(async move { directory_host::init_host(&mut rx, &mut directory).await }); - process_mode(&args, &tx, &db).await?; + process_mode(&args, &tx, &state).await?; Ok(()) } @@ -221,7 +224,7 @@ async fn pre_process_mode(args: &CliArgs, db: &DatabaseType) -> Result, - db: &DatabaseType, + db: &AkdDatabase, ) -> Result<()> { if let Some(mode) = &args.mode { match mode { @@ -260,7 +263,7 @@ async fn process_mode( Ok(()) } -async fn bench_db_insert(num_users: u64, db: &DatabaseType) -> Result<()> { +async fn bench_db_insert(num_users: u64, db: &AkdDatabase) -> Result<()> { use owo_colors::OwoColorize; println!("{}", "======= Benchmark operation requested =======".cyan()); @@ -488,7 +491,7 @@ async fn bench_lookup( async fn repl_loop( _args: &CliArgs, tx: &Sender, - db: &DatabaseType, + db: &AkdDatabase, ) -> Result<()> { loop { println!("Please enter a command"); @@ -498,7 +501,7 @@ async fn repl_loop( let mut line = String::new(); stdin().read_line(&mut line)?; - match (db, Command::parse(&mut line)) { + match (db.db(), Command::parse(&mut line)) { (_, Command::Unknown(other)) => { println!("Input '{other}' is not supported, enter 'help' for the help menu") } diff --git a/akd/crates/common/src/config/mod.rs b/akd/crates/common/src/config/mod.rs deleted file mode 100644 index f0efb0d00f..0000000000 --- a/akd/crates/common/src/config/mod.rs +++ /dev/null @@ -1,43 +0,0 @@ -use thiserror::Error; - -mod storage_manager_config; - -pub use akd_storage::vrf_key_config::VrfKeyConfig; -pub use storage_manager_config::*; - -#[derive(Error, Debug)] -pub enum ConfigError { - #[error("Failed to connect to database")] - DatabaseConnection(#[source] akd::errors::StorageError), - - #[error("Configuration value 'cache_item_lifetime_ms' is invalid: value {value} exceeds maximum allowed ({max})")] - CacheLifetimeOutOfRange { - value: usize, - max: u64, - #[source] - source: std::num::TryFromIntError, - }, - - #[error("Configuration value 'cache_clean_frequency_ms' is invalid: value {value} exceeds maximum allowed ({max})")] - CacheCleanFrequencyOutOfRange { - value: usize, - max: u64, - #[source] - source: std::num::TryFromIntError, - }, - - #[error("{0}")] - Custom(String), - - #[error("Invalid hex string for VRF key material")] - InvalidVrfKeyMaterialHex(#[source] hex::FromHexError), - - #[error("VRF key material must be exactly 32 bytes, got {actual} bytes")] - VrfKeyMaterialInvalidLength { actual: usize }, -} - -impl ConfigError { - pub fn new(message: impl Into) -> Self { - Self::Custom(message.into()) - } -} diff --git a/akd/crates/common/src/config/storage_manager_config.rs b/akd/crates/common/src/config/storage_manager_config.rs deleted file mode 100644 index 6876471658..0000000000 --- a/akd/crates/common/src/config/storage_manager_config.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::time::Duration; - -use crate::config::ConfigError; -use akd::storage::StorageManager; -use akd_storage::{db_config::DbConfig, DatabaseType}; -use serde::{Deserialize, Serialize}; - -/// items live for 30s by default -pub const DEFAULT_ITEM_LIFETIME_MS: usize = 30_000; -/// clean the cache every 15s by default -pub const DEFAULT_CACHE_CLEAN_FREQUENCY_MS: usize = 15_000; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct StorageManagerConfig { - pub db_config: DbConfig, - pub cache_limit_bytes: Option, - #[serde(default = "default_cache_item_lifetime_ms")] - pub cache_item_lifetime_ms: usize, - #[serde(default = "default_cache_clean_frequency_ms")] - pub cache_clean_frequency_ms: usize, -} - -fn default_cache_item_lifetime_ms() -> usize { - DEFAULT_ITEM_LIFETIME_MS -} - -fn default_cache_clean_frequency_ms() -> usize { - DEFAULT_CACHE_CLEAN_FREQUENCY_MS -} - -impl StorageManagerConfig { - pub async fn create(&self) -> Result, ConfigError> { - Ok(StorageManager::new( - self.db_config - .connect() - .await - .map_err(ConfigError::DatabaseConnection)?, - Some(Duration::from_millis( - self.cache_item_lifetime_ms.try_into().map_err(|source| { - ConfigError::CacheLifetimeOutOfRange { - value: self.cache_item_lifetime_ms, - max: u64::MAX, - source, - } - })?, - )), - self.cache_limit_bytes, - Some(Duration::from_millis( - self.cache_clean_frequency_ms.try_into().map_err(|source| { - ConfigError::CacheCleanFrequencyOutOfRange { - value: self.cache_clean_frequency_ms, - max: u64::MAX, - source, - } - })?, - )), - )) - } -} diff --git a/akd/crates/common/src/lib.rs b/akd/crates/common/src/lib.rs index ef68c36943..8b13789179 100644 --- a/akd/crates/common/src/lib.rs +++ b/akd/crates/common/src/lib.rs @@ -1 +1 @@ -pub mod config; + diff --git a/akd/crates/publisher/src/lib.rs b/akd/crates/publisher/src/lib.rs index 44e9ccab93..d58866cd83 100644 --- a/akd/crates/publisher/src/lib.rs +++ b/akd/crates/publisher/src/lib.rs @@ -1,7 +1,6 @@ use akd::{directory::Directory, storage::StorageManager}; use akd_storage::DatabaseType; use bitwarden_akd_configuration::BitwardenV1Configuration; -use common::VrfStorageType; use tracing::instrument; struct AppState { diff --git a/akd/crates/reader/src/lib.rs b/akd/crates/reader/src/lib.rs index 80058ba582..eb51c83f24 100644 --- a/akd/crates/reader/src/lib.rs +++ b/akd/crates/reader/src/lib.rs @@ -1,7 +1,6 @@ use akd::{directory::ReadOnlyDirectory, storage::StorageManager}; use akd_storage::DatabaseType; use bitwarden_akd_configuration::BitwardenV1Configuration; -use common::VrfStorageType; use tracing::instrument; struct AppState { @@ -10,7 +9,7 @@ struct AppState { } #[instrument(skip_all, name = "reader_start")] -pub async fn start(db: DatabaseType, vrf: VrfStorageType) { +pub async fn start(db: DatabaseType) { let storage_manager = StorageManager::new_no_cache(db); let _app = AppState { _directory: ReadOnlyDirectory::new(storage_manager, vrf)