mirror of
https://github.com/bitwarden/server
synced 2026-02-04 10:43:53 +00:00
Clean up akd storage interface for application usage
This commit is contained in:
1
akd/Cargo.lock
generated
1
akd/Cargo.lock
generated
@@ -94,6 +94,7 @@ dependencies = [
|
||||
"blake3",
|
||||
"chacha20poly1305",
|
||||
"ed25519-dalek",
|
||||
"hex",
|
||||
"ms_database",
|
||||
"rand 0.8.5",
|
||||
"rsa",
|
||||
|
||||
@@ -21,6 +21,7 @@ bitwarden-akd-configuration = { path = "crates/bitwarden-akd-configuration" }
|
||||
blake3 = "1.8.2"
|
||||
common = { path = "crates/common" }
|
||||
config = "0.15.18"
|
||||
hex = "0.4.3"
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
tokio = { version = "1.47.1", features = ["full"] }
|
||||
tracing = { version = "0.1.41" }
|
||||
|
||||
@@ -14,6 +14,7 @@ bitwarden-encoding = { path = "../bitwarden-encoding" }
|
||||
blake3.workspace = true
|
||||
chacha20poly1305 = { version = "0.10.1" }
|
||||
ed25519-dalek = { version = ">=2.1.1, <=2.2.0", features = ["rand_core"] }
|
||||
hex.workspace = true
|
||||
ms_database = { path = "../ms_database" }
|
||||
rand = ">=0.8.5, <0.9"
|
||||
rsa = { version = ">=0.9.2, <0.10" }
|
||||
|
||||
@@ -1 +1 @@
|
||||
DROP TABLE IF EXISTS dbo.vrf_key;
|
||||
DROP TABLE IF EXISTS dbo.akd_vrf_keys;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
IF OBJECT_ID('dbo.vrf_key', 'U') IS NULL
|
||||
IF OBJECT_ID('dbo.akd_vrf_keys', 'U') IS NULL
|
||||
BEGIN
|
||||
CREATE TABLE dbo.vrf_key (
|
||||
CREATE TABLE dbo.akd_vrf_keys (
|
||||
root_key_hash VARBINARY(32) NOT NULL,
|
||||
root_key_type INT NOT NULL,
|
||||
enc_sym_key VARBINARY(32) NULL,
|
||||
enc_sym_key_nonce VARBINARY(24) NULL,
|
||||
sym_enc_vrf_key VARBINARY(32) NOT NULL,
|
||||
root_key_type SMALLINT NOT NULL,
|
||||
enc_sym_key VARBINARY(max) NULL,
|
||||
sym_enc_vrf_key VARBINARY(48) NOT NULL,
|
||||
sym_enc_vrf_key_nonce VARBINARY(24) NULL,
|
||||
PRIMARY KEY (root_key_hash, root_key_type)
|
||||
);
|
||||
END
|
||||
|
||||
101
akd/crates/akd_storage/src/akd_database.rs
Normal file
101
akd/crates/akd_storage/src/akd_database.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use akd::{
|
||||
errors::StorageError,
|
||||
storage::{
|
||||
types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag},
|
||||
Database, DbSetState, Storable,
|
||||
},
|
||||
AkdLabel, AkdValue,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
db_config::DatabaseType, vrf_key_config::VrfKeyConfig, vrf_key_database::VrfKeyDatabase,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AkdDatabase {
|
||||
db: DatabaseType,
|
||||
vrf_key_config: VrfKeyConfig,
|
||||
}
|
||||
|
||||
impl AkdDatabase {
|
||||
pub fn db(&self) -> &DatabaseType {
|
||||
&self.db
|
||||
}
|
||||
|
||||
pub fn new(db: DatabaseType, vrf_key_config: VrfKeyConfig) -> AkdDatabase {
|
||||
AkdDatabase { db, vrf_key_config }
|
||||
}
|
||||
|
||||
pub fn vrf_key_database(&self) -> VrfKeyDatabase {
|
||||
VrfKeyDatabase::new(self.db.clone(), self.vrf_key_config.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for AkdDatabase {
|
||||
async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.set(record).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn batch_set(
|
||||
&self,
|
||||
records: Vec<DbRecord>,
|
||||
state: DbSetState, // TODO: unused in mysql example, but may be needed later
|
||||
) -> Result<(), StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.batch_set(records, state).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.get::<St>(id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn batch_get<St: Storable>(
|
||||
&self,
|
||||
ids: &[St::StorageKey],
|
||||
) -> Result<Vec<DbRecord>, StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.batch_get::<St>(ids).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<KeyData, StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.get_user_data(raw_label).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_state(
|
||||
&self,
|
||||
raw_label: &AkdLabel,
|
||||
flag: ValueStateRetrievalFlag,
|
||||
) -> Result<ValueState, StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_state_versions(
|
||||
&self,
|
||||
raw_labels: &[AkdLabel],
|
||||
flag: ValueStateRetrievalFlag,
|
||||
) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,65 @@
|
||||
use serde::Deserialize;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::db_config::DbConfig;
|
||||
use akd::storage::StorageManager;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{db_config::DbConfig, vrf_key_config::VrfKeyConfig, AkdDatabase};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AkdStorageConfig {
|
||||
_db_config: DbConfig,
|
||||
pub db_config: DbConfig,
|
||||
/// Controls how long items stay in cache before being evicted (in milliseconds). Defaults to 30 seconds.
|
||||
#[serde(default = "default_cache_item_lifetime_ms")]
|
||||
_cache_item_lifetime_ms: usize,
|
||||
pub cache_item_lifetime_ms: usize,
|
||||
/// Controls the maximum size of the cache in bytes. Defaults to no limit.
|
||||
#[serde(default)]
|
||||
_cache_limit_bytes: Option<usize>,
|
||||
pub cache_limit_bytes: Option<usize>,
|
||||
/// Controls how often the cache is cleaned (in milliseconds). Defaults to 15 seconds.
|
||||
#[serde(default = "default_cache_clean_ms")]
|
||||
_cache_clean_ms: usize,
|
||||
pub cache_clean_ms: usize,
|
||||
pub vrf_key_config: VrfKeyConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Failed to initialize storage")]
|
||||
pub struct AkdStorageInitializationError;
|
||||
|
||||
impl AkdStorageConfig {
|
||||
pub async fn initialize_storage(
|
||||
&self,
|
||||
) -> Result<(StorageManager<AkdDatabase>, AkdDatabase), AkdStorageInitializationError> {
|
||||
let db = self.db_config.connect().await.map_err(|err| {
|
||||
error!(%err, "Failed to connect to database");
|
||||
AkdStorageInitializationError
|
||||
})?;
|
||||
|
||||
let state = AkdDatabase::new(db, self.vrf_key_config.clone());
|
||||
|
||||
let cache_item_lifetime = Some(Duration::from_millis(
|
||||
self.cache_item_lifetime_ms.try_into().map_err(|err| {
|
||||
error!(%err, "Cache item lifetime out of range");
|
||||
AkdStorageInitializationError
|
||||
})?,
|
||||
));
|
||||
let cache_clean_frequency = Some(Duration::from_millis(
|
||||
self.cache_clean_ms.try_into().map_err(|err| {
|
||||
error!(%err, "Cache clean interval out of range");
|
||||
AkdStorageInitializationError
|
||||
})?,
|
||||
));
|
||||
|
||||
Ok((
|
||||
StorageManager::new(
|
||||
state.clone(),
|
||||
cache_item_lifetime,
|
||||
self.cache_limit_bytes,
|
||||
cache_clean_frequency,
|
||||
),
|
||||
state,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn default_cache_item_lifetime_ms() -> usize {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use akd::errors::StorageError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::DatabaseType;
|
||||
use crate::ms_sql::MsSql;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
@@ -12,6 +12,13 @@ pub enum DbConfig {
|
||||
},
|
||||
}
|
||||
|
||||
/// Enum to represent different database types supported by the storage layer.
|
||||
/// Each variant is cheap to clone for reuse across threads.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DatabaseType {
|
||||
MsSql(MsSql),
|
||||
}
|
||||
|
||||
impl DbConfig {
|
||||
pub async fn connect(&self) -> Result<DatabaseType, StorageError> {
|
||||
let db = match self {
|
||||
|
||||
@@ -1,92 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use akd::{
|
||||
errors::StorageError,
|
||||
storage::{
|
||||
types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag},
|
||||
Database, DbSetState, Storable,
|
||||
},
|
||||
AkdLabel, AkdValue,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::ms_sql::MsSql;
|
||||
|
||||
mod akd_database;
|
||||
pub mod akd_storage_config;
|
||||
pub mod db_config;
|
||||
pub mod ms_sql;
|
||||
pub mod vrf_key_config;
|
||||
pub mod vrf_key_database;
|
||||
|
||||
/// Enum to represent different database types supported by the storage layer.
|
||||
/// Each variant is cheap to clone for reuse across threads.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DatabaseType {
|
||||
MsSql(MsSql),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for DatabaseType {
|
||||
async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.set(record).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn batch_set(
|
||||
&self,
|
||||
records: Vec<DbRecord>,
|
||||
state: DbSetState, // TODO: unused in mysql example, but may be needed later
|
||||
) -> Result<(), StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.batch_set(records, state).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.get::<St>(id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn batch_get<St: Storable>(
|
||||
&self,
|
||||
ids: &[St::StorageKey],
|
||||
) -> Result<Vec<DbRecord>, StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.batch_get::<St>(ids).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<KeyData, StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.get_user_data(raw_label).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_state(
|
||||
&self,
|
||||
raw_label: &AkdLabel,
|
||||
flag: ValueStateRetrievalFlag,
|
||||
) -> Result<ValueState, StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await,
|
||||
}
|
||||
}
|
||||
|
||||
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
|
||||
// too restrictive for what we want, so generalize the name a bit.
|
||||
async fn get_user_state_versions(
|
||||
&self,
|
||||
raw_labels: &[AkdLabel],
|
||||
flag: ValueStateRetrievalFlag,
|
||||
) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
|
||||
match self {
|
||||
DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use akd_database::*;
|
||||
|
||||
@@ -4,6 +4,7 @@ use ms_database::{load_migrations, Migration};
|
||||
pub const TABLE_AZKS: &str = "akd_azks";
|
||||
pub const TABLE_HISTORY_TREE_NODES: &str = "akd_history_tree_nodes";
|
||||
pub const TABLE_VALUES: &str = "akd_values";
|
||||
pub const TABLE_VRF_KEYS: &str = "akd_vrf_keys";
|
||||
pub const TABLE_MIGRATIONS: &str = ms_database::TABLE_MIGRATIONS;
|
||||
|
||||
pub(crate) const MIGRATIONS: &[Migration] = load_migrations!("migrations/ms_sql");
|
||||
|
||||
@@ -18,6 +18,7 @@ use tracing::{debug, error, info, instrument, trace, warn};
|
||||
|
||||
use migrations::{
|
||||
MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES,
|
||||
TABLE_VRF_KEYS,
|
||||
};
|
||||
use tables::{
|
||||
akd_storable_for_ms_sql::{AkdStorableForMsSql, Statement},
|
||||
@@ -25,6 +26,12 @@ use tables::{
|
||||
values,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
ms_sql::tables::vrf_key,
|
||||
vrf_key_config::VrfKeyConfig,
|
||||
vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData},
|
||||
};
|
||||
|
||||
const DEFAULT_POOL_SIZE: u32 = 100;
|
||||
|
||||
pub struct MsSqlBuilder {
|
||||
@@ -116,6 +123,7 @@ impl MsSql {
|
||||
DROP TABLE IF EXISTS {TABLE_AZKS};
|
||||
DROP TABLE IF EXISTS {TABLE_HISTORY_TREE_NODES};
|
||||
DROP TABLE IF EXISTS {TABLE_VALUES};
|
||||
DROP TABLE IF EXISTS {TABLE_VRF_KEYS};
|
||||
DROP TABLE IF EXISTS {TABLE_MIGRATIONS};"#
|
||||
);
|
||||
|
||||
@@ -164,6 +172,66 @@ impl MsSql {
|
||||
}
|
||||
}
|
||||
|
||||
impl MsSql {
|
||||
#[instrument(skip(self, config), level = "debug")]
|
||||
pub async fn get_vrf_key(
|
||||
&self,
|
||||
config: &VrfKeyConfig,
|
||||
) -> Result<VrfKeyTableData, VrfKeyRetrievalError> {
|
||||
debug!("Retrieving VRF key from database");
|
||||
|
||||
let mut conn = self.get_connection().await.map_err(|err| {
|
||||
error!(%err, "Failed to get DB connection for VRF key retrieval");
|
||||
VrfKeyRetrievalError::DatabaseError
|
||||
})?;
|
||||
|
||||
let statement = vrf_key::get_statement(&config).map_err(|err| {
|
||||
error!(%err, "Failed to build VRF key retrieval statement");
|
||||
VrfKeyRetrievalError::CorruptedData
|
||||
})?;
|
||||
let query_stream = conn
|
||||
.query(statement.sql(), &statement.params())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, "Failed to execute VRF key retrieval query");
|
||||
VrfKeyRetrievalError::DatabaseError
|
||||
})?;
|
||||
|
||||
let row = query_stream.into_row().await.map_err(|err| {
|
||||
error!(%err, "Failed to fetch VRF key row");
|
||||
VrfKeyRetrievalError::DatabaseError
|
||||
})?;
|
||||
|
||||
match row {
|
||||
None => {
|
||||
debug!("VRF key not found");
|
||||
return Err(VrfKeyRetrievalError::KeyNotFound);
|
||||
}
|
||||
Some(row) => {
|
||||
debug!("VRF key found");
|
||||
vrf_key::from_row(&row).map_err(|err| {
|
||||
error!(%err, "Failed to parse VRF key from database row");
|
||||
VrfKeyRetrievalError::CorruptedData
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self, table_data), level = "debug")]
|
||||
pub async fn store_vrf_key(
|
||||
&self,
|
||||
table_data: &VrfKeyTableData,
|
||||
) -> Result<(), VrfKeyStorageError> {
|
||||
debug!("Storing VRF key in database");
|
||||
|
||||
let statement = vrf_key::store_statement(table_data);
|
||||
self.execute_statement(&statement).await.map_err(|err| {
|
||||
error!(%err, "Failed to store VRF key in database");
|
||||
VrfKeyStorageError
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for MsSql {
|
||||
#[instrument(skip(self, record), level = "debug")]
|
||||
@@ -552,7 +620,9 @@ impl Database for MsSql {
|
||||
statement.parse(&row)
|
||||
} else {
|
||||
debug!("Raw label not found");
|
||||
Err(StorageError::NotFound("ValueState for label not found".to_string()))
|
||||
Err(StorageError::NotFound(
|
||||
"ValueState for label not found".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,101 @@
|
||||
use crate::vrf_key_database::VrfKeyTableData;
|
||||
use akd::errors::StorageError;
|
||||
use tracing::debug;
|
||||
|
||||
#[allow(unused)]
|
||||
pub async fn get_vrf_key(root_key: &[u8]) -> VrfKeyTableData {
|
||||
todo!()
|
||||
use crate::{
|
||||
ms_sql::{
|
||||
migrations::TABLE_VRF_KEYS,
|
||||
sql_params::SqlParams,
|
||||
tables::akd_storable_for_ms_sql::{QueryStatement, Statement},
|
||||
},
|
||||
vrf_key_config::{VrfKeyConfig, VrfRootKeyError},
|
||||
vrf_key_database::VrfKeyTableData,
|
||||
};
|
||||
|
||||
pub fn get_statement(
|
||||
config: &VrfKeyConfig,
|
||||
) -> Result<QueryStatement<VrfKeyTableData>, VrfRootKeyError> {
|
||||
debug!("Building get_statement for vrf key");
|
||||
let mut params = SqlParams::new();
|
||||
params.add("root_key_type", Box::new(config.root_key_type() as i16));
|
||||
params.add(
|
||||
"root_key_hash",
|
||||
Box::new(config.root_key_hash().expect("valid root key hash")),
|
||||
);
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT root_key_hash, root_key_type, enc_sym_key, sym_enc_vrf_key, sym_enc_vrf_key_nonce
|
||||
FROM {}
|
||||
WHERE root_key_type = {} AND root_key_hash = {}"#,
|
||||
TABLE_VRF_KEYS,
|
||||
params
|
||||
.key_for("root_key_type")
|
||||
.expect("root_key_type was added to the params list"),
|
||||
params
|
||||
.key_for("root_key_hash")
|
||||
.expect("root_key_hash was added to the params list"),
|
||||
);
|
||||
Ok(QueryStatement::new(sql, params, from_row))
|
||||
}
|
||||
|
||||
pub fn from_row(row: &ms_database::Row) -> Result<VrfKeyTableData, StorageError> {
|
||||
let root_key_hash: &[u8] = row
|
||||
.get("root_key_hash")
|
||||
.ok_or_else(|| StorageError::Other("Missing root_key_hash column".to_string()))?;
|
||||
let root_key_type: i16 = row
|
||||
.get("root_key_type")
|
||||
.ok_or_else(|| StorageError::Other("root_key_type is NULL or missing".to_string()))?;
|
||||
let enc_sym_key: Option<&[u8]> = row.get("enc_sym_key");
|
||||
let sym_enc_vrf_key: &[u8] = row
|
||||
.get("sym_enc_vrf_key")
|
||||
.ok_or_else(|| StorageError::Other("sym_enc_vrf_key is NULL or missing".to_string()))?;
|
||||
let sym_enc_vrf_key_nonce: &[u8] = row.get("sym_enc_vrf_key_nonce").ok_or_else(|| {
|
||||
StorageError::Other("sym_enc_vrf_key_nonce is NULL or missing".to_string())
|
||||
})?;
|
||||
|
||||
Ok(VrfKeyTableData {
|
||||
root_key_hash: root_key_hash.to_vec(),
|
||||
root_key_type: root_key_type.into(),
|
||||
enc_sym_key: enc_sym_key.map(|k| k.to_vec()),
|
||||
sym_enc_vrf_key: sym_enc_vrf_key.to_vec(),
|
||||
sym_enc_vrf_key_nonce: sym_enc_vrf_key_nonce.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn store_statement(table_data: &VrfKeyTableData) -> Statement {
|
||||
debug!("Building store_statement for vrf key");
|
||||
let mut params = SqlParams::new();
|
||||
params.add("root_key_hash", Box::new(table_data.root_key_hash.clone()));
|
||||
params.add(
|
||||
"root_key_type",
|
||||
Box::new(Into::<i16>::into(table_data.root_key_type)),
|
||||
);
|
||||
params.add("enc_sym_key", Box::new(table_data.enc_sym_key.clone()));
|
||||
params.add(
|
||||
"sym_enc_vrf_key",
|
||||
Box::new(table_data.sym_enc_vrf_key.clone()),
|
||||
);
|
||||
params.add(
|
||||
"sym_enc_vrf_key_nonce",
|
||||
Box::new(table_data.sym_enc_vrf_key_nonce.clone()),
|
||||
);
|
||||
let sql = format!(
|
||||
r#"
|
||||
MERGE INTO dbo.{TABLE_VRF_KEYS} AS t
|
||||
USING (SELECT {}) AS source
|
||||
ON t.root_key_hash = source.root_key_hash AND t.root_key_type = source.root_key_type
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET {}
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT ({})
|
||||
VALUES ({});"#,
|
||||
params.keys_as_columns().join(", "),
|
||||
params
|
||||
.set_columns_equal_except("t.", "source.", vec!["root_key_hash", "root_key_type"])
|
||||
.join(", "),
|
||||
params.columns().join(", "),
|
||||
params.keys().join(", "),
|
||||
);
|
||||
|
||||
Statement::new(sql, params)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use rsa::pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tracing::error;
|
||||
|
||||
use crate::vrf_key_database::VrfRootKeyType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
@@ -36,3 +43,60 @@ pub enum VrfKeyConfig {
|
||||
/// Losing this key is equivalent to losing your directory's VRF key.
|
||||
PEMEncodedRSAKey { private_key: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Error reading root key from configuration")]
|
||||
pub struct VrfRootKeyError;
|
||||
|
||||
impl VrfKeyConfig {
|
||||
pub fn root_key_bytes(&self) -> Result<Vec<u8>, VrfRootKeyError> {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
VrfKeyConfig::ConstantVrfKey => {
|
||||
// This is the hard coded vrf key
|
||||
Ok(
|
||||
hex::decode("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721")
|
||||
.map_err(|err| {
|
||||
error!("Failed to decode hardcoded vrf key: {}", err);
|
||||
VrfRootKeyError
|
||||
})?,
|
||||
)
|
||||
}
|
||||
VrfKeyConfig::B64EncodedSymmetricKey { key } => bitwarden_encoding::B64::from_str(&key)
|
||||
.map_err(|err| {
|
||||
error!(%err, "Failed to decode symmetric key from base64 format");
|
||||
VrfRootKeyError
|
||||
})
|
||||
.map(|b64| Vec::<u8>::from(b64)),
|
||||
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
|
||||
let rsa_private_key =
|
||||
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
|
||||
error!(%err, "Failed to decode RSA private key from PEM format");
|
||||
VrfRootKeyError
|
||||
})?;
|
||||
Ok(rsa_private_key
|
||||
.to_pkcs1_der()
|
||||
.map_err(|err| {
|
||||
error!(%err, "Failed to encode RSA private key to DER format");
|
||||
VrfRootKeyError
|
||||
})?
|
||||
.as_bytes()
|
||||
.to_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn root_key_hash(&self) -> Result<Vec<u8>, VrfRootKeyError> {
|
||||
let root_key_bytes = self.root_key_bytes()?;
|
||||
Ok(blake3::hash(&root_key_bytes).as_bytes().to_vec())
|
||||
}
|
||||
|
||||
pub fn root_key_type(&self) -> VrfRootKeyType {
|
||||
match self {
|
||||
#[cfg(test)]
|
||||
VrfKeyConfig::ConstantVrfKey => VrfRootKeyType::None,
|
||||
VrfKeyConfig::B64EncodedSymmetricKey { .. } => VrfRootKeyType::SymmetricKey,
|
||||
VrfKeyConfig::PEMEncodedRSAKey { .. } => VrfRootKeyType::RsaKey,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use akd::ecvrf::{VRFKeyStorage, VrfError};
|
||||
use async_trait::async_trait;
|
||||
use chacha20poly1305::{
|
||||
aead::{generic_array::GenericArray, Aead},
|
||||
AeadCore, KeyInit, XChaCha20Poly1305,
|
||||
};
|
||||
use rsa::{
|
||||
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
|
||||
Pkcs1v15Encrypt,
|
||||
};
|
||||
use rsa::{pkcs1::DecodeRsaPrivateKey, Pkcs1v15Encrypt};
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::vrf_key_config::VrfKeyConfig;
|
||||
use crate::{db_config::DatabaseType, vrf_key_config::VrfKeyConfig};
|
||||
|
||||
/// Represents a storage-layer error
|
||||
#[derive(Debug, Error)]
|
||||
@@ -35,31 +32,116 @@ pub struct VrfKeyCreationError;
|
||||
#[error("Internal VRF key storage error")]
|
||||
pub struct VrfKeyStorageError;
|
||||
|
||||
#[allow(unused)]
|
||||
trait VrfKeyTable {
|
||||
async fn get_vrf_key(config: VrfKeyConfig) -> Result<VrfKeyTableData, VrfKeyRetrievalError>;
|
||||
async fn store_vrf_key(table_data: VrfKeyTableData) -> Result<(), VrfKeyStorageError>;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VrfKeyDatabase {
|
||||
db: DatabaseType,
|
||||
vrf_key_config: VrfKeyConfig,
|
||||
cached_vrf_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl VrfKeyDatabase {
|
||||
pub fn new(db: DatabaseType, config: VrfKeyConfig) -> VrfKeyDatabase {
|
||||
VrfKeyDatabase {
|
||||
db,
|
||||
vrf_key_config: config,
|
||||
cached_vrf_key: None,
|
||||
}
|
||||
}
|
||||
async fn get_vrf_key(&self) -> Result<VrfKeyTableData, VrfKeyRetrievalError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.get_vrf_key(&self.vrf_key_config).await,
|
||||
}
|
||||
}
|
||||
async fn store_vrf_key(&self, table_data: &VrfKeyTableData) -> Result<(), VrfKeyStorageError> {
|
||||
match &self.db {
|
||||
DatabaseType::MsSql(db) => db.store_vrf_key(table_data).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VRFKeyStorage for VrfKeyDatabase {
|
||||
async fn retrieve(&self) -> Result<Vec<u8>, VrfError> {
|
||||
if let Some(cached_key) = &self.cached_vrf_key {
|
||||
return Ok(cached_key.clone());
|
||||
}
|
||||
|
||||
match &self.get_vrf_key().await {
|
||||
Ok(table_data) => table_data
|
||||
.to_vrf_key(&self.vrf_key_config)
|
||||
.await
|
||||
.map(|k| k.0)
|
||||
.map_err(|err| {
|
||||
VrfError::SigningKey(format!("Error decrypting signing key: {err}"))
|
||||
}),
|
||||
Err(VrfKeyRetrievalError::KeyNotFound) => {
|
||||
// Make a new key
|
||||
let (table_data, key) =
|
||||
VrfKeyTableData::new(&self.vrf_key_config)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
VrfError::SigningKey(format!("Failed to create a new VRF key: {err}"))
|
||||
})?;
|
||||
// store store
|
||||
self.store_vrf_key(&table_data).await.map_err(|err| {
|
||||
VrfError::SigningKey(format!("Failed to store a new VRF key: {err}"))
|
||||
})?;
|
||||
|
||||
// and return
|
||||
Ok(key.0)
|
||||
}
|
||||
Err(err) => {
|
||||
error!(%err, "Key retrieval error");
|
||||
Err(VrfError::SigningKey("Key retrieval error".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VrfKeyTableData {
|
||||
pub root_key_hash: Vec<u8>,
|
||||
pub root_key_type: RootKeyType,
|
||||
pub root_key_type: VrfRootKeyType,
|
||||
pub enc_sym_key: Option<Vec<u8>>,
|
||||
pub sym_enc_vrf_key: Vec<u8>,
|
||||
pub sym_enc_vrf_key_nonce: Vec<u8>,
|
||||
}
|
||||
|
||||
pub enum RootKeyType {
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub enum VrfRootKeyType {
|
||||
#[cfg(test)]
|
||||
None = 0,
|
||||
SymmetricKey = 1,
|
||||
RsaKey = 2,
|
||||
}
|
||||
|
||||
impl From<i16> for VrfRootKeyType {
|
||||
fn from(value: i16) -> Self {
|
||||
match value {
|
||||
1 => VrfRootKeyType::SymmetricKey,
|
||||
2 => VrfRootKeyType::RsaKey,
|
||||
#[cfg(test)]
|
||||
0 => VrfRootKeyType::None,
|
||||
_ => panic!("Invalid VrfRootKeyType value: {}", value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VrfRootKeyType> for i16 {
|
||||
fn from(value: VrfRootKeyType) -> Self {
|
||||
match value {
|
||||
VrfRootKeyType::SymmetricKey => 1,
|
||||
VrfRootKeyType::RsaKey => 2,
|
||||
#[cfg(test)]
|
||||
VrfRootKeyType::None => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VrfKey(pub Vec<u8>);
|
||||
|
||||
impl VrfKeyTableData {
|
||||
pub async fn new(config: VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> {
|
||||
pub async fn new(config: &VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> {
|
||||
info!("Generating new VRF key and table data");
|
||||
// handle constant key case separately to avoid unnecessary key generation / parsing
|
||||
#[cfg(test)]
|
||||
@@ -68,8 +150,8 @@ impl VrfKeyTableData {
|
||||
|
||||
return Ok((
|
||||
VrfKeyTableData {
|
||||
root_key_hash: vec![],
|
||||
root_key_type: RootKeyType::None,
|
||||
root_key_hash: config.root_key_hash().expect("hard coded vrf key"),
|
||||
root_key_type: VrfRootKeyType::None,
|
||||
enc_sym_key: None,
|
||||
sym_enc_vrf_key: vec![],
|
||||
sym_enc_vrf_key_nonce: vec![],
|
||||
@@ -78,17 +160,14 @@ impl VrfKeyTableData {
|
||||
));
|
||||
}
|
||||
|
||||
let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key } = &config {
|
||||
let raw_key = bitwarden_encoding::B64::from_str(key).map_err(|err| {
|
||||
error!(%err, "Invalid b64 encoding of symmetric key");
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key: _ } = &config {
|
||||
let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?;
|
||||
(
|
||||
XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| {
|
||||
XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| {
|
||||
error!(%err, "Invalid symmetric key length");
|
||||
VrfKeyCreationError
|
||||
})?,
|
||||
raw_key.as_bytes().to_vec(),
|
||||
raw_key,
|
||||
)
|
||||
} else {
|
||||
let key = XChaCha20Poly1305::generate_key(rand::thread_rng());
|
||||
@@ -103,16 +182,22 @@ impl VrfKeyTableData {
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
|
||||
match config {
|
||||
match &config {
|
||||
#[cfg(test)]
|
||||
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
|
||||
VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => {
|
||||
let root_key_hash = blake3::hash(&sym_key).as_bytes().to_vec();
|
||||
let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?;
|
||||
|
||||
error!(
|
||||
rkh = root_key_hash.len(),
|
||||
sevk = sym_enc_vrf_key.len(),
|
||||
sevkn = nonce.len(),
|
||||
"lengths of stuff!!!!!\n\n\n\n"
|
||||
);
|
||||
Ok((
|
||||
VrfKeyTableData {
|
||||
root_key_hash,
|
||||
root_key_type: RootKeyType::SymmetricKey,
|
||||
root_key_type: VrfRootKeyType::SymmetricKey,
|
||||
enc_sym_key: None,
|
||||
sym_enc_vrf_key,
|
||||
sym_enc_vrf_key_nonce: nonce.to_vec(),
|
||||
@@ -122,21 +207,11 @@ impl VrfKeyTableData {
|
||||
}
|
||||
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
|
||||
let rsa_private_key =
|
||||
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
|
||||
rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| {
|
||||
error!(%err, "Failed to decode RSA private key from PEM format");
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
let root_key_hash = blake3::hash(
|
||||
rsa_private_key
|
||||
.to_pkcs1_der()
|
||||
.map_err(|err| {
|
||||
error!(%err, "Failed to encode RSA private key to DER format");
|
||||
VrfKeyCreationError
|
||||
})?
|
||||
.as_bytes(),
|
||||
)
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?;
|
||||
let rsa_public_key = rsa_private_key.to_public_key();
|
||||
let enc_sym_key = rsa_public_key
|
||||
.encrypt(&mut rand::thread_rng(), Pkcs1v15Encrypt, &sym_key)
|
||||
@@ -148,7 +223,7 @@ impl VrfKeyTableData {
|
||||
Ok((
|
||||
VrfKeyTableData {
|
||||
root_key_hash,
|
||||
root_key_type: RootKeyType::RsaKey,
|
||||
root_key_type: VrfRootKeyType::RsaKey,
|
||||
enc_sym_key: Some(enc_sym_key),
|
||||
sym_enc_vrf_key,
|
||||
sym_enc_vrf_key_nonce: nonce.to_vec(),
|
||||
@@ -159,7 +234,7 @@ impl VrfKeyTableData {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn to_vrf_key(&self, config: VrfKeyConfig) -> Result<VrfKey, VrfKeyCreationError> {
|
||||
pub async fn to_vrf_key(&self, config: &VrfKeyConfig) -> Result<VrfKey, VrfKeyCreationError> {
|
||||
info!("Decrypting VrfKeyTableData to obtain VRF key");
|
||||
// handle constant key case separately to avoid unnecessary key generation / parsing
|
||||
#[cfg(test)]
|
||||
@@ -179,15 +254,12 @@ impl VrfKeyTableData {
|
||||
return Err(VrfKeyCreationError);
|
||||
}
|
||||
let nonce = GenericArray::from_slice(self.sym_enc_vrf_key_nonce.as_ref());
|
||||
let vrf_key = match config {
|
||||
let vrf_key = match &config {
|
||||
#[cfg(test)]
|
||||
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
|
||||
VrfKeyConfig::B64EncodedSymmetricKey { key } => {
|
||||
let raw_key = bitwarden_encoding::B64::from_str(&key).map_err(|err| {
|
||||
error!(%err, "Invalid b64 encoding of symmetric key");
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
let sym = XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| {
|
||||
VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => {
|
||||
let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?;
|
||||
let sym = XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| {
|
||||
error!(%err, "Invalid symmetric key length");
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
@@ -202,7 +274,7 @@ impl VrfKeyTableData {
|
||||
}
|
||||
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
|
||||
let rsa_private_key =
|
||||
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
|
||||
rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| {
|
||||
error!(%err, "Failed to decode RSA private key from PEM format");
|
||||
VrfKeyCreationError
|
||||
})?;
|
||||
@@ -303,8 +375,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_generation_from_symmetric_key() {
|
||||
let config = create_test_symmetric_config();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
|
||||
|
||||
assert_eq!(table_data.enc_sym_key, None);
|
||||
assert_eq!(
|
||||
@@ -322,8 +394,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
pub async fn test_generation_from_rsa_key() {
|
||||
let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
|
||||
let config = create_test_rsa_config();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config.clone()).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
|
||||
assert_eq!(
|
||||
table_data.root_key_hash,
|
||||
[
|
||||
@@ -343,8 +415,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_generation_from_constant_key() {
|
||||
let config = super::VrfKeyConfig::ConstantVrfKey;
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config.clone()).await.unwrap();
|
||||
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
|
||||
|
||||
assert_eq!(table_data.root_key_hash, vec![]);
|
||||
assert_eq!(table_data.enc_sym_key, None);
|
||||
@@ -360,19 +432,19 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
key: "not!valid@base64#".to_string(),
|
||||
};
|
||||
|
||||
let result = super::VrfKeyTableData::new(config.clone()).await;
|
||||
let result = super::VrfKeyTableData::new(&config.clone()).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_invalid_base64_during_retrieval() {
|
||||
let config_valid = create_test_symmetric_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config_valid).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config_valid).await.unwrap();
|
||||
|
||||
let config_invalid = super::VrfKeyConfig::B64EncodedSymmetricKey {
|
||||
key: "not!valid@base64#".to_string(),
|
||||
};
|
||||
let result = table_data.to_vrf_key(config_invalid).await;
|
||||
let result = table_data.to_vrf_key(&config_invalid).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
@@ -381,7 +453,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
let short_key = bitwarden_encoding::B64::from(vec![0u8; 16]).to_string();
|
||||
let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: short_key };
|
||||
|
||||
let result = super::VrfKeyTableData::new(config).await;
|
||||
let result = super::VrfKeyTableData::new(&config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
@@ -393,7 +465,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
private_key: malformed_pem.to_string(),
|
||||
};
|
||||
|
||||
let result = super::VrfKeyTableData::new(config).await;
|
||||
let result = super::VrfKeyTableData::new(&config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
@@ -406,62 +478,62 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
private_key: missing_headers.to_string(),
|
||||
};
|
||||
|
||||
let result = super::VrfKeyTableData::new(config).await;
|
||||
let result = super::VrfKeyTableData::new(&config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_wrong_symmetric_key_decryption() {
|
||||
let config1 = create_test_symmetric_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap();
|
||||
|
||||
let config2 = super::VrfKeyConfig::B64EncodedSymmetricKey {
|
||||
key: generate_random_symmetric_key_b64(),
|
||||
};
|
||||
|
||||
let result = table_data.to_vrf_key(config2).await;
|
||||
let result = table_data.to_vrf_key(&config2).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_wrong_rsa_key_decryption() {
|
||||
let config1 = create_test_rsa_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap();
|
||||
|
||||
let config2 = super::VrfKeyConfig::PEMEncodedRSAKey {
|
||||
private_key: generate_random_rsa_key_pem(),
|
||||
};
|
||||
|
||||
let result = table_data.to_vrf_key(config2).await;
|
||||
let result = table_data.to_vrf_key(&config2).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_rsa_missing_enc_sym_key() {
|
||||
let config = create_test_rsa_config();
|
||||
let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
|
||||
let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
table_data.enc_sym_key = None;
|
||||
|
||||
let result = table_data.to_vrf_key(config).await;
|
||||
let result = table_data.to_vrf_key(&config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_wrong_nonce() {
|
||||
let config = create_test_symmetric_config();
|
||||
let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
|
||||
let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
table_data.sym_enc_vrf_key_nonce.truncate(10);
|
||||
|
||||
let result = table_data.to_vrf_key(config).await;
|
||||
let result = table_data.to_vrf_key(&config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_nonce_size_validation() {
|
||||
let config = create_test_symmetric_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
assert_eq!(table_data.sym_enc_vrf_key_nonce.len(), 24);
|
||||
}
|
||||
@@ -470,7 +542,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
pub async fn test_empty_symmetric_key() {
|
||||
let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: String::new() };
|
||||
|
||||
let result = super::VrfKeyTableData::new(config).await;
|
||||
let result = super::VrfKeyTableData::new(&config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -480,7 +552,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
private_key: String::new(),
|
||||
};
|
||||
|
||||
let result = super::VrfKeyTableData::new(config).await;
|
||||
let result = super::VrfKeyTableData::new(&config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
@@ -489,8 +561,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
let config1 = create_test_symmetric_config();
|
||||
let config2 = create_test_symmetric_config();
|
||||
|
||||
let (table_data1, vrf_key1) = super::VrfKeyTableData::new(config1).await.unwrap();
|
||||
let (table_data2, vrf_key2) = super::VrfKeyTableData::new(config2).await.unwrap();
|
||||
let (table_data1, vrf_key1) = super::VrfKeyTableData::new(&config1).await.unwrap();
|
||||
let (table_data2, vrf_key2) = super::VrfKeyTableData::new(&config2).await.unwrap();
|
||||
|
||||
assert_ne!(vrf_key1.0, vrf_key2.0);
|
||||
assert_ne!(
|
||||
@@ -502,7 +574,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_symmetric_key_not_persisted() {
|
||||
let config = create_test_symmetric_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
let symmetric_key_bytes =
|
||||
bitwarden_encoding::B64::from_str(TEST_SYMMETRIC_KEY_B64).unwrap();
|
||||
@@ -516,7 +588,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_rsa_private_key_not_persisted() {
|
||||
let config = create_test_rsa_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
|
||||
let rsa_der = rsa_key.to_pkcs1_der().unwrap();
|
||||
@@ -532,7 +604,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_vrf_key_is_encrypted_at_rest() {
|
||||
let config = create_test_symmetric_config();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(config).await.unwrap();
|
||||
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
assert!(!table_data
|
||||
.sym_enc_vrf_key
|
||||
@@ -543,7 +615,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_symmetric_key_encryption_in_rsa_mode() {
|
||||
let config = create_test_rsa_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
|
||||
|
||||
let enc_sym_key = table_data.enc_sym_key.as_ref().unwrap();
|
||||
let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
|
||||
@@ -556,22 +628,22 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
|
||||
#[tokio::test]
|
||||
pub async fn test_cannot_decrypt_symmetric_with_rsa_config() {
|
||||
let sym_config = create_test_symmetric_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(sym_config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&sym_config).await.unwrap();
|
||||
|
||||
let rsa_config = create_test_rsa_config();
|
||||
|
||||
let result = table_data.to_vrf_key(rsa_config).await;
|
||||
let result = table_data.to_vrf_key(&rsa_config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_cannot_decrypt_rsa_with_symmetric_config() {
|
||||
let rsa_config = create_test_rsa_config();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(rsa_config).await.unwrap();
|
||||
let (table_data, _) = super::VrfKeyTableData::new(&rsa_config).await.unwrap();
|
||||
|
||||
let sym_config = create_test_symmetric_config();
|
||||
|
||||
let result = table_data.to_vrf_key(sym_config).await;
|
||||
let result = table_data.to_vrf_key(&sym_config).await;
|
||||
assert!(matches!(result, Err(super::VrfKeyCreationError)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use akd::ecvrf::HardCodedAkdVRF;
|
||||
use akd::storage::StorageManager;
|
||||
use akd::Directory;
|
||||
use akd_storage::db_config::DbConfig;
|
||||
use akd_storage::DatabaseType;
|
||||
use akd_storage::akd_storage_config::AkdStorageConfig;
|
||||
use akd_storage::db_config::{DatabaseType, DbConfig};
|
||||
use akd_storage::AkdDatabase;
|
||||
use anyhow::{Context, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use commands::Command;
|
||||
@@ -159,33 +158,37 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Create database connection
|
||||
info!("Connecting to MS SQL database");
|
||||
let config = DbConfig::MsSql {
|
||||
connection_string,
|
||||
pool_size: args.pool_size,
|
||||
let config = AkdStorageConfig {
|
||||
db_config: DbConfig::MsSql {
|
||||
connection_string: connection_string.clone(),
|
||||
pool_size: args.pool_size,
|
||||
},
|
||||
cache_item_lifetime_ms: 30000,
|
||||
cache_limit_bytes: None,
|
||||
cache_clean_ms: 15000,
|
||||
vrf_key_config: akd_storage::vrf_key_config::VrfKeyConfig::B64EncodedSymmetricKey {
|
||||
key: "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=".to_string(),
|
||||
},
|
||||
};
|
||||
let db = config
|
||||
.connect()
|
||||
let (storage_manager, state) = config
|
||||
.initialize_storage()
|
||||
.await
|
||||
.context("Failed to connect to database")?;
|
||||
.context("Failed to initialize storage")?;
|
||||
|
||||
// Handle pre-processing modes
|
||||
if let Some(()) = pre_process_mode(&args, &db).await? {
|
||||
if let Some(()) = pre_process_mode(&args, &state.db()).await? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let storage_manager = StorageManager::new(db.clone(), None, None, None);
|
||||
let vrf = HardCodedAkdVRF {};
|
||||
let mut directory = Directory::<TC, _, _>::new(storage_manager.clone(), vrf)
|
||||
let mut directory = Directory::<TC, _, _>::new(storage_manager, state.vrf_key_database())
|
||||
.await
|
||||
.context("Failed to create AKD directory")?;
|
||||
|
||||
let (tx, mut rx) = channel(2);
|
||||
|
||||
tokio::spawn(async move {
|
||||
directory_host::init_host::<TC, _, HardCodedAkdVRF>(&mut rx, &mut directory).await
|
||||
});
|
||||
tokio::spawn(async move { directory_host::init_host(&mut rx, &mut directory).await });
|
||||
|
||||
process_mode(&args, &tx, &db).await?;
|
||||
process_mode(&args, &tx, &state).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -221,7 +224,7 @@ async fn pre_process_mode(args: &CliArgs, db: &DatabaseType) -> Result<Option<()
|
||||
async fn process_mode(
|
||||
args: &CliArgs,
|
||||
tx: &Sender<directory_host::Rpc>,
|
||||
db: &DatabaseType,
|
||||
db: &AkdDatabase,
|
||||
) -> Result<()> {
|
||||
if let Some(mode) = &args.mode {
|
||||
match mode {
|
||||
@@ -260,7 +263,7 @@ async fn process_mode(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn bench_db_insert(num_users: u64, db: &DatabaseType) -> Result<()> {
|
||||
async fn bench_db_insert(num_users: u64, db: &AkdDatabase) -> Result<()> {
|
||||
use owo_colors::OwoColorize;
|
||||
|
||||
println!("{}", "======= Benchmark operation requested =======".cyan());
|
||||
@@ -488,7 +491,7 @@ async fn bench_lookup(
|
||||
async fn repl_loop(
|
||||
_args: &CliArgs,
|
||||
tx: &Sender<directory_host::Rpc>,
|
||||
db: &DatabaseType,
|
||||
db: &AkdDatabase,
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
println!("Please enter a command");
|
||||
@@ -498,7 +501,7 @@ async fn repl_loop(
|
||||
let mut line = String::new();
|
||||
stdin().read_line(&mut line)?;
|
||||
|
||||
match (db, Command::parse(&mut line)) {
|
||||
match (db.db(), Command::parse(&mut line)) {
|
||||
(_, Command::Unknown(other)) => {
|
||||
println!("Input '{other}' is not supported, enter 'help' for the help menu")
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
use thiserror::Error;
|
||||
|
||||
mod storage_manager_config;
|
||||
|
||||
pub use akd_storage::vrf_key_config::VrfKeyConfig;
|
||||
pub use storage_manager_config::*;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ConfigError {
|
||||
#[error("Failed to connect to database")]
|
||||
DatabaseConnection(#[source] akd::errors::StorageError),
|
||||
|
||||
#[error("Configuration value 'cache_item_lifetime_ms' is invalid: value {value} exceeds maximum allowed ({max})")]
|
||||
CacheLifetimeOutOfRange {
|
||||
value: usize,
|
||||
max: u64,
|
||||
#[source]
|
||||
source: std::num::TryFromIntError,
|
||||
},
|
||||
|
||||
#[error("Configuration value 'cache_clean_frequency_ms' is invalid: value {value} exceeds maximum allowed ({max})")]
|
||||
CacheCleanFrequencyOutOfRange {
|
||||
value: usize,
|
||||
max: u64,
|
||||
#[source]
|
||||
source: std::num::TryFromIntError,
|
||||
},
|
||||
|
||||
#[error("{0}")]
|
||||
Custom(String),
|
||||
|
||||
#[error("Invalid hex string for VRF key material")]
|
||||
InvalidVrfKeyMaterialHex(#[source] hex::FromHexError),
|
||||
|
||||
#[error("VRF key material must be exactly 32 bytes, got {actual} bytes")]
|
||||
VrfKeyMaterialInvalidLength { actual: usize },
|
||||
}
|
||||
|
||||
impl ConfigError {
|
||||
pub fn new(message: impl Into<String>) -> Self {
|
||||
Self::Custom(message.into())
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::ConfigError;
|
||||
use akd::storage::StorageManager;
|
||||
use akd_storage::{db_config::DbConfig, DatabaseType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// items live for 30s by default
|
||||
pub const DEFAULT_ITEM_LIFETIME_MS: usize = 30_000;
|
||||
/// clean the cache every 15s by default
|
||||
pub const DEFAULT_CACHE_CLEAN_FREQUENCY_MS: usize = 15_000;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StorageManagerConfig {
|
||||
pub db_config: DbConfig,
|
||||
pub cache_limit_bytes: Option<usize>,
|
||||
#[serde(default = "default_cache_item_lifetime_ms")]
|
||||
pub cache_item_lifetime_ms: usize,
|
||||
#[serde(default = "default_cache_clean_frequency_ms")]
|
||||
pub cache_clean_frequency_ms: usize,
|
||||
}
|
||||
|
||||
fn default_cache_item_lifetime_ms() -> usize {
|
||||
DEFAULT_ITEM_LIFETIME_MS
|
||||
}
|
||||
|
||||
fn default_cache_clean_frequency_ms() -> usize {
|
||||
DEFAULT_CACHE_CLEAN_FREQUENCY_MS
|
||||
}
|
||||
|
||||
impl StorageManagerConfig {
|
||||
pub async fn create(&self) -> Result<StorageManager<DatabaseType>, ConfigError> {
|
||||
Ok(StorageManager::new(
|
||||
self.db_config
|
||||
.connect()
|
||||
.await
|
||||
.map_err(ConfigError::DatabaseConnection)?,
|
||||
Some(Duration::from_millis(
|
||||
self.cache_item_lifetime_ms.try_into().map_err(|source| {
|
||||
ConfigError::CacheLifetimeOutOfRange {
|
||||
value: self.cache_item_lifetime_ms,
|
||||
max: u64::MAX,
|
||||
source,
|
||||
}
|
||||
})?,
|
||||
)),
|
||||
self.cache_limit_bytes,
|
||||
Some(Duration::from_millis(
|
||||
self.cache_clean_frequency_ms.try_into().map_err(|source| {
|
||||
ConfigError::CacheCleanFrequencyOutOfRange {
|
||||
value: self.cache_clean_frequency_ms,
|
||||
max: u64::MAX,
|
||||
source,
|
||||
}
|
||||
})?,
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
pub mod config;
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use akd::{directory::Directory, storage::StorageManager};
|
||||
use akd_storage::DatabaseType;
|
||||
use bitwarden_akd_configuration::BitwardenV1Configuration;
|
||||
use common::VrfStorageType;
|
||||
use tracing::instrument;
|
||||
|
||||
struct AppState {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use akd::{directory::ReadOnlyDirectory, storage::StorageManager};
|
||||
use akd_storage::DatabaseType;
|
||||
use bitwarden_akd_configuration::BitwardenV1Configuration;
|
||||
use common::VrfStorageType;
|
||||
use tracing::instrument;
|
||||
|
||||
struct AppState {
|
||||
@@ -10,7 +9,7 @@ struct AppState {
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "reader_start")]
|
||||
pub async fn start(db: DatabaseType, vrf: VrfStorageType) {
|
||||
pub async fn start(db: DatabaseType) {
|
||||
let storage_manager = StorageManager::new_no_cache(db);
|
||||
let _app = AppState {
|
||||
_directory: ReadOnlyDirectory::new(storage_manager, vrf)
|
||||
|
||||
Reference in New Issue
Block a user