
Clean up akd storage interface for application usage

Author: Matt Gibson
Date: 2025-12-18 16:42:27 -08:00
parent 7eda815adb
commit 929527a5c2
20 changed files with 588 additions and 314 deletions

akd/Cargo.lock (generated)

@@ -94,6 +94,7 @@ dependencies = [
"blake3",
"chacha20poly1305",
"ed25519-dalek",
"hex",
"ms_database",
"rand 0.8.5",
"rsa",


@@ -21,6 +21,7 @@ bitwarden-akd-configuration = { path = "crates/bitwarden-akd-configuration" }
blake3 = "1.8.2"
common = { path = "crates/common" }
config = "0.15.18"
hex = "0.4.3"
serde = { version = "1.0.228", features = ["derive"] }
tokio = { version = "1.47.1", features = ["full"] }
tracing = { version = "0.1.41" }


@@ -14,6 +14,7 @@ bitwarden-encoding = { path = "../bitwarden-encoding" }
blake3.workspace = true
chacha20poly1305 = { version = "0.10.1" }
ed25519-dalek = { version = ">=2.1.1, <=2.2.0", features = ["rand_core"] }
hex.workspace = true
ms_database = { path = "../ms_database" }
rand = ">=0.8.5, <0.9"
rsa = { version = ">=0.9.2, <0.10" }


@@ -1 +1 @@
DROP TABLE IF EXISTS dbo.vrf_key;
DROP TABLE IF EXISTS dbo.akd_vrf_keys;


@@ -1,11 +1,11 @@
IF OBJECT_ID('dbo.vrf_key', 'U') IS NULL
IF OBJECT_ID('dbo.akd_vrf_keys', 'U') IS NULL
BEGIN
CREATE TABLE dbo.vrf_key (
CREATE TABLE dbo.akd_vrf_keys (
root_key_hash VARBINARY(32) NOT NULL,
root_key_type INT NOT NULL,
enc_sym_key VARBINARY(32) NULL,
enc_sym_key_nonce VARBINARY(24) NULL,
sym_enc_vrf_key VARBINARY(32) NOT NULL,
root_key_type SMALLINT NOT NULL,
enc_sym_key VARBINARY(max) NULL,
sym_enc_vrf_key VARBINARY(48) NOT NULL,
sym_enc_vrf_key_nonce VARBINARY(24) NULL,
PRIMARY KEY (root_key_hash, root_key_type)
);
END


@@ -0,0 +1,101 @@
use std::collections::HashMap;
use akd::{
errors::StorageError,
storage::{
types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag},
Database, DbSetState, Storable,
},
AkdLabel, AkdValue,
};
use async_trait::async_trait;
use crate::{
db_config::DatabaseType, vrf_key_config::VrfKeyConfig, vrf_key_database::VrfKeyDatabase,
};
#[derive(Debug, Clone)]
pub struct AkdDatabase {
db: DatabaseType,
vrf_key_config: VrfKeyConfig,
}
impl AkdDatabase {
pub fn db(&self) -> &DatabaseType {
&self.db
}
pub fn new(db: DatabaseType, vrf_key_config: VrfKeyConfig) -> AkdDatabase {
AkdDatabase { db, vrf_key_config }
}
pub fn vrf_key_database(&self) -> VrfKeyDatabase {
VrfKeyDatabase::new(self.db.clone(), self.vrf_key_config.clone())
}
}
#[async_trait]
impl Database for AkdDatabase {
async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.set(record).await,
}
}
async fn batch_set(
&self,
records: Vec<DbRecord>,
state: DbSetState, // TODO: unused in mysql example, but may be needed later
) -> Result<(), StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.batch_set(records, state).await,
}
}
async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.get::<St>(id).await,
}
}
async fn batch_get<St: Storable>(
&self,
ids: &[St::StorageKey],
) -> Result<Vec<DbRecord>, StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.batch_get::<St>(ids).await,
}
}
// Note: what the upstream API calls `user`/`username` is the raw_label here. Assuming a single key per user
// is too restrictive for what we want, so the name is generalized a bit.
async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<KeyData, StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.get_user_data(raw_label).await,
}
}
// Note: what the upstream API calls `user`/`username` is the raw_label here. Assuming a single key per user
// is too restrictive for what we want, so the name is generalized a bit.
async fn get_user_state(
&self,
raw_label: &AkdLabel,
flag: ValueStateRetrievalFlag,
) -> Result<ValueState, StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await,
}
}
// Note: what the upstream API calls `user`/`username` is the raw_label here. Assuming a single key per user
// is too restrictive for what we want, so the name is generalized a bit.
async fn get_user_state_versions(
&self,
raw_labels: &[AkdLabel],
flag: ValueStateRetrievalFlag,
) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await,
}
}
}
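
As a usage note, here is a minimal wiring sketch (not part of this commit) showing how the constructors above are intended to compose from an application's point of view; the calling context and error handling are assumptions.

use akd_storage::{db_config::DbConfig, vrf_key_config::VrfKeyConfig, AkdDatabase};

// Hedged sketch: build the application-facing database wrapper from a DbConfig
// plus the VRF key configuration.
async fn build_akd_database(
    db_config: &DbConfig,
    vrf_key_config: &VrfKeyConfig,
) -> Result<AkdDatabase, akd::errors::StorageError> {
    // connect() yields the DatabaseType enum (currently only the MsSql variant).
    let db = db_config.connect().await?;
    // The wrapper carries the connection plus the VRF key configuration;
    // vrf_key_database() later hands out the matching VRFKeyStorage implementation.
    Ok(AkdDatabase::new(db, vrf_key_config.clone()))
}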


@@ -1,19 +1,65 @@
use serde::Deserialize;
use std::time::Duration;
use crate::db_config::DbConfig;
use akd::storage::StorageManager;
use serde::Deserialize;
use thiserror::Error;
use tracing::error;
use crate::{db_config::DbConfig, vrf_key_config::VrfKeyConfig, AkdDatabase};
#[derive(Debug, Clone, Deserialize)]
pub struct AkdStorageConfig {
_db_config: DbConfig,
pub db_config: DbConfig,
/// Controls how long items stay in cache before being evicted (in milliseconds). Defaults to 30 seconds.
#[serde(default = "default_cache_item_lifetime_ms")]
_cache_item_lifetime_ms: usize,
pub cache_item_lifetime_ms: usize,
/// Controls the maximum size of the cache in bytes. Defaults to no limit.
#[serde(default)]
_cache_limit_bytes: Option<usize>,
pub cache_limit_bytes: Option<usize>,
/// Controls how often the cache is cleaned (in milliseconds). Defaults to 15 seconds.
#[serde(default = "default_cache_clean_ms")]
_cache_clean_ms: usize,
pub cache_clean_ms: usize,
pub vrf_key_config: VrfKeyConfig,
}
#[derive(Debug, Error)]
#[error("Failed to initialize storage")]
pub struct AkdStorageInitializationError;
impl AkdStorageConfig {
pub async fn initialize_storage(
&self,
) -> Result<(StorageManager<AkdDatabase>, AkdDatabase), AkdStorageInitializationError> {
let db = self.db_config.connect().await.map_err(|err| {
error!(%err, "Failed to connect to database");
AkdStorageInitializationError
})?;
let state = AkdDatabase::new(db, self.vrf_key_config.clone());
let cache_item_lifetime = Some(Duration::from_millis(
self.cache_item_lifetime_ms.try_into().map_err(|err| {
error!(%err, "Cache item lifetime out of range");
AkdStorageInitializationError
})?,
));
let cache_clean_frequency = Some(Duration::from_millis(
self.cache_clean_ms.try_into().map_err(|err| {
error!(%err, "Cache clean interval out of range");
AkdStorageInitializationError
})?,
));
Ok((
StorageManager::new(
state.clone(),
cache_item_lifetime,
self.cache_limit_bytes,
cache_clean_frequency,
),
state,
))
}
}
fn default_cache_item_lifetime_ms() -> usize {
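
To orient readers, a hedged sketch of the intended call path from this config to a running AKD directory, mirroring the CLI changes later in this commit; using `BitwardenV1Configuration` as the AKD configuration type and anyhow-style error handling are assumptions.

use akd::Directory;
use akd_storage::akd_storage_config::AkdStorageConfig;
use bitwarden_akd_configuration::BitwardenV1Configuration;

// Hedged sketch: one call connects the database, wraps it in AkdDatabase, and
// builds the caching StorageManager; the same AkdDatabase then supplies the VRF
// key storage for the directory.
async fn start_directory(config: AkdStorageConfig) -> anyhow::Result<()> {
    let (storage_manager, state) = config.initialize_storage().await?;
    let _directory = Directory::<BitwardenV1Configuration, _, _>::new(
        storage_manager,
        state.vrf_key_database(),
    )
    .await?;
    // Publish/lookup operations would go through `_directory` from here.
    Ok(())
}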


@@ -1,7 +1,7 @@
use akd::errors::StorageError;
use serde::{Deserialize, Serialize};
use crate::DatabaseType;
use crate::ms_sql::MsSql;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
@@ -12,6 +12,13 @@ pub enum DbConfig {
},
}
/// Enum to represent different database types supported by the storage layer.
/// Each variant is cheap to clone for reuse across threads.
#[derive(Debug, Clone)]
pub enum DatabaseType {
MsSql(MsSql),
}
impl DbConfig {
pub async fn connect(&self) -> Result<DatabaseType, StorageError> {
let db = match self {


@@ -1,92 +1,8 @@
use std::collections::HashMap;
use akd::{
errors::StorageError,
storage::{
types::{DbRecord, KeyData, ValueState, ValueStateRetrievalFlag},
Database, DbSetState, Storable,
},
AkdLabel, AkdValue,
};
use async_trait::async_trait;
use crate::ms_sql::MsSql;
mod akd_database;
pub mod akd_storage_config;
pub mod db_config;
pub mod ms_sql;
pub mod vrf_key_config;
pub mod vrf_key_database;
/// Enum to represent different database types supported by the storage layer.
/// Each variant is cheap to clone for reuse across threads.
#[derive(Debug, Clone)]
pub enum DatabaseType {
MsSql(MsSql),
}
#[async_trait]
impl Database for DatabaseType {
async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
match self {
DatabaseType::MsSql(db) => db.set(record).await,
}
}
async fn batch_set(
&self,
records: Vec<DbRecord>,
state: DbSetState, // TODO: unused in mysql example, but may be needed later
) -> Result<(), StorageError> {
match self {
DatabaseType::MsSql(db) => db.batch_set(records, state).await,
}
}
async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
match self {
DatabaseType::MsSql(db) => db.get::<St>(id).await,
}
}
async fn batch_get<St: Storable>(
&self,
ids: &[St::StorageKey],
) -> Result<Vec<DbRecord>, StorageError> {
match self {
DatabaseType::MsSql(db) => db.batch_get::<St>(ids).await,
}
}
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<KeyData, StorageError> {
match self {
DatabaseType::MsSql(db) => db.get_user_data(raw_label).await,
}
}
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_state(
&self,
raw_label: &AkdLabel,
flag: ValueStateRetrievalFlag,
) -> Result<ValueState, StorageError> {
match self {
DatabaseType::MsSql(db) => db.get_user_state(raw_label, flag).await,
}
}
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_state_versions(
&self,
raw_labels: &[AkdLabel],
flag: ValueStateRetrievalFlag,
) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
match self {
DatabaseType::MsSql(db) => db.get_user_state_versions(raw_labels, flag).await,
}
}
}
pub use akd_database::*;


@@ -4,6 +4,7 @@ use ms_database::{load_migrations, Migration};
pub const TABLE_AZKS: &str = "akd_azks";
pub const TABLE_HISTORY_TREE_NODES: &str = "akd_history_tree_nodes";
pub const TABLE_VALUES: &str = "akd_values";
pub const TABLE_VRF_KEYS: &str = "akd_vrf_keys";
pub const TABLE_MIGRATIONS: &str = ms_database::TABLE_MIGRATIONS;
pub(crate) const MIGRATIONS: &[Migration] = load_migrations!("migrations/ms_sql");


@@ -18,6 +18,7 @@ use tracing::{debug, error, info, instrument, trace, warn};
use migrations::{
MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES,
TABLE_VRF_KEYS,
};
use tables::{
akd_storable_for_ms_sql::{AkdStorableForMsSql, Statement},
@@ -25,6 +26,12 @@ use tables::{
values,
};
use crate::{
ms_sql::tables::vrf_key,
vrf_key_config::VrfKeyConfig,
vrf_key_database::{VrfKeyRetrievalError, VrfKeyStorageError, VrfKeyTableData},
};
const DEFAULT_POOL_SIZE: u32 = 100;
pub struct MsSqlBuilder {
@@ -116,6 +123,7 @@ impl MsSql {
DROP TABLE IF EXISTS {TABLE_AZKS};
DROP TABLE IF EXISTS {TABLE_HISTORY_TREE_NODES};
DROP TABLE IF EXISTS {TABLE_VALUES};
DROP TABLE IF EXISTS {TABLE_VRF_KEYS};
DROP TABLE IF EXISTS {TABLE_MIGRATIONS};"#
);
@@ -164,6 +172,66 @@ impl MsSql {
}
}
impl MsSql {
#[instrument(skip(self, config), level = "debug")]
pub async fn get_vrf_key(
&self,
config: &VrfKeyConfig,
) -> Result<VrfKeyTableData, VrfKeyRetrievalError> {
debug!("Retrieving VRF key from database");
let mut conn = self.get_connection().await.map_err(|err| {
error!(%err, "Failed to get DB connection for VRF key retrieval");
VrfKeyRetrievalError::DatabaseError
})?;
let statement = vrf_key::get_statement(&config).map_err(|err| {
error!(%err, "Failed to build VRF key retrieval statement");
VrfKeyRetrievalError::CorruptedData
})?;
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|err| {
error!(%err, "Failed to execute VRF key retrieval query");
VrfKeyRetrievalError::DatabaseError
})?;
let row = query_stream.into_row().await.map_err(|err| {
error!(%err, "Failed to fetch VRF key row");
VrfKeyRetrievalError::DatabaseError
})?;
match row {
None => {
debug!("VRF key not found");
return Err(VrfKeyRetrievalError::KeyNotFound);
}
Some(row) => {
debug!("VRF key found");
vrf_key::from_row(&row).map_err(|err| {
error!(%err, "Failed to parse VRF key from database row");
VrfKeyRetrievalError::CorruptedData
})
}
}
}
#[instrument(skip(self, table_data), level = "debug")]
pub async fn store_vrf_key(
&self,
table_data: &VrfKeyTableData,
) -> Result<(), VrfKeyStorageError> {
debug!("Storing VRF key in database");
let statement = vrf_key::store_statement(table_data);
self.execute_statement(&statement).await.map_err(|err| {
error!(%err, "Failed to store VRF key in database");
VrfKeyStorageError
})
}
}
#[async_trait]
impl Database for MsSql {
#[instrument(skip(self, record), level = "debug")]
@@ -552,7 +620,9 @@ impl Database for MsSql {
statement.parse(&row)
} else {
debug!("Raw label not found");
Err(StorageError::NotFound("ValueState for label not found".to_string()))
Err(StorageError::NotFound(
"ValueState for label not found".to_string(),
))
}
}


@@ -1,6 +1,101 @@
use crate::vrf_key_database::VrfKeyTableData;
use akd::errors::StorageError;
use tracing::debug;
#[allow(unused)]
pub async fn get_vrf_key(root_key: &[u8]) -> VrfKeyTableData {
todo!()
use crate::{
ms_sql::{
migrations::TABLE_VRF_KEYS,
sql_params::SqlParams,
tables::akd_storable_for_ms_sql::{QueryStatement, Statement},
},
vrf_key_config::{VrfKeyConfig, VrfRootKeyError},
vrf_key_database::VrfKeyTableData,
};
pub fn get_statement(
config: &VrfKeyConfig,
) -> Result<QueryStatement<VrfKeyTableData>, VrfRootKeyError> {
debug!("Building get_statement for vrf key");
let mut params = SqlParams::new();
params.add("root_key_type", Box::new(config.root_key_type() as i16));
params.add(
"root_key_hash",
Box::new(config.root_key_hash().expect("valid root key hash")),
);
let sql = format!(
r#"
SELECT root_key_hash, root_key_type, enc_sym_key, sym_enc_vrf_key, sym_enc_vrf_key_nonce
FROM {}
WHERE root_key_type = {} AND root_key_hash = {}"#,
TABLE_VRF_KEYS,
params
.key_for("root_key_type")
.expect("root_key_type was added to the params list"),
params
.key_for("root_key_hash")
.expect("root_key_hash was added to the params list"),
);
Ok(QueryStatement::new(sql, params, from_row))
}
pub fn from_row(row: &ms_database::Row) -> Result<VrfKeyTableData, StorageError> {
let root_key_hash: &[u8] = row
.get("root_key_hash")
.ok_or_else(|| StorageError::Other("Missing root_key_hash column".to_string()))?;
let root_key_type: i16 = row
.get("root_key_type")
.ok_or_else(|| StorageError::Other("root_key_type is NULL or missing".to_string()))?;
let enc_sym_key: Option<&[u8]> = row.get("enc_sym_key");
let sym_enc_vrf_key: &[u8] = row
.get("sym_enc_vrf_key")
.ok_or_else(|| StorageError::Other("sym_enc_vrf_key is NULL or missing".to_string()))?;
let sym_enc_vrf_key_nonce: &[u8] = row.get("sym_enc_vrf_key_nonce").ok_or_else(|| {
StorageError::Other("sym_enc_vrf_key_nonce is NULL or missing".to_string())
})?;
Ok(VrfKeyTableData {
root_key_hash: root_key_hash.to_vec(),
root_key_type: root_key_type.into(),
enc_sym_key: enc_sym_key.map(|k| k.to_vec()),
sym_enc_vrf_key: sym_enc_vrf_key.to_vec(),
sym_enc_vrf_key_nonce: sym_enc_vrf_key_nonce.to_vec(),
})
}
pub fn store_statement(table_data: &VrfKeyTableData) -> Statement {
debug!("Building store_statement for vrf key");
let mut params = SqlParams::new();
params.add("root_key_hash", Box::new(table_data.root_key_hash.clone()));
params.add(
"root_key_type",
Box::new(Into::<i16>::into(table_data.root_key_type)),
);
params.add("enc_sym_key", Box::new(table_data.enc_sym_key.clone()));
params.add(
"sym_enc_vrf_key",
Box::new(table_data.sym_enc_vrf_key.clone()),
);
params.add(
"sym_enc_vrf_key_nonce",
Box::new(table_data.sym_enc_vrf_key_nonce.clone()),
);
let sql = format!(
r#"
MERGE INTO dbo.{TABLE_VRF_KEYS} AS t
USING (SELECT {}) AS source
ON t.root_key_hash = source.root_key_hash AND t.root_key_type = source.root_key_type
WHEN MATCHED THEN
UPDATE SET {}
WHEN NOT MATCHED THEN
INSERT ({})
VALUES ({});"#,
params.keys_as_columns().join(", "),
params
.set_columns_equal_except("t.", "source.", vec!["root_key_hash", "root_key_type"])
.join(", "),
params.columns().join(", "),
params.keys().join(", "),
);
Statement::new(sql, params)
}


@@ -1,4 +1,11 @@
use std::str::FromStr;
use rsa::pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::error;
use crate::vrf_key_database::VrfRootKeyType;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
@@ -36,3 +43,60 @@ pub enum VrfKeyConfig {
/// Losing this key is equivalent to losing your directory's VRF key.
PEMEncodedRSAKey { private_key: String },
}
#[derive(Debug, Error)]
#[error("Error reading root key from configuration")]
pub struct VrfRootKeyError;
impl VrfKeyConfig {
pub fn root_key_bytes(&self) -> Result<Vec<u8>, VrfRootKeyError> {
match self {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => {
// This is the hard coded vrf key
Ok(
hex::decode("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721")
.map_err(|err| {
error!("Failed to decode hardcoded vrf key: {}", err);
VrfRootKeyError
})?,
)
}
VrfKeyConfig::B64EncodedSymmetricKey { key } => bitwarden_encoding::B64::from_str(&key)
.map_err(|err| {
error!(%err, "Failed to decode symmetric key from base64 format");
VrfRootKeyError
})
.map(|b64| Vec::<u8>::from(b64)),
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
let rsa_private_key =
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
error!(%err, "Failed to decode RSA private key from PEM format");
VrfRootKeyError
})?;
Ok(rsa_private_key
.to_pkcs1_der()
.map_err(|err| {
error!(%err, "Failed to encode RSA private key to DER format");
VrfRootKeyError
})?
.as_bytes()
.to_vec())
}
}
}
pub fn root_key_hash(&self) -> Result<Vec<u8>, VrfRootKeyError> {
let root_key_bytes = self.root_key_bytes()?;
Ok(blake3::hash(&root_key_bytes).as_bytes().to_vec())
}
pub fn root_key_type(&self) -> VrfRootKeyType {
match self {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => VrfRootKeyType::None,
VrfKeyConfig::B64EncodedSymmetricKey { .. } => VrfRootKeyType::SymmetricKey,
VrfKeyConfig::PEMEncodedRSAKey { .. } => VrfRootKeyType::RsaKey,
}
}
}
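
A small worked example of the root_key_hash derivation above for the symmetric-key case; the helper name is hypothetical, but the calls mirror the code in this file.

use std::str::FromStr;

use akd_storage::vrf_key_config::VrfRootKeyError;

// Hedged sketch: root_key_hash is the BLAKE3 hash of the decoded root key bytes;
// together with root_key_type it forms the primary key of the akd_vrf_keys table.
fn symmetric_root_key_hash(key_b64: &str) -> Result<Vec<u8>, VrfRootKeyError> {
    let b64 = bitwarden_encoding::B64::from_str(key_b64).map_err(|_| VrfRootKeyError)?;
    let key_bytes = Vec::<u8>::from(b64);
    // blake3::hash always yields 32 bytes, matching the VARBINARY(32) column.
    Ok(blake3::hash(&key_bytes).as_bytes().to_vec())
}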


@@ -1,17 +1,14 @@
use std::str::FromStr;
use akd::ecvrf::{VRFKeyStorage, VrfError};
use async_trait::async_trait;
use chacha20poly1305::{
aead::{generic_array::GenericArray, Aead},
AeadCore, KeyInit, XChaCha20Poly1305,
};
use rsa::{
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
Pkcs1v15Encrypt,
};
use rsa::{pkcs1::DecodeRsaPrivateKey, Pkcs1v15Encrypt};
use thiserror::Error;
use tracing::{error, info};
use crate::vrf_key_config::VrfKeyConfig;
use crate::{db_config::DatabaseType, vrf_key_config::VrfKeyConfig};
/// Represents a storage-layer error
#[derive(Debug, Error)]
@@ -35,31 +32,116 @@ pub struct VrfKeyCreationError;
#[error("Internal VRF key storage error")]
pub struct VrfKeyStorageError;
#[allow(unused)]
trait VrfKeyTable {
async fn get_vrf_key(config: VrfKeyConfig) -> Result<VrfKeyTableData, VrfKeyRetrievalError>;
async fn store_vrf_key(table_data: VrfKeyTableData) -> Result<(), VrfKeyStorageError>;
#[derive(Debug, Clone)]
pub struct VrfKeyDatabase {
db: DatabaseType,
vrf_key_config: VrfKeyConfig,
cached_vrf_key: Option<Vec<u8>>,
}
impl VrfKeyDatabase {
pub fn new(db: DatabaseType, config: VrfKeyConfig) -> VrfKeyDatabase {
VrfKeyDatabase {
db,
vrf_key_config: config,
cached_vrf_key: None,
}
}
async fn get_vrf_key(&self) -> Result<VrfKeyTableData, VrfKeyRetrievalError> {
match &self.db {
DatabaseType::MsSql(db) => db.get_vrf_key(&self.vrf_key_config).await,
}
}
async fn store_vrf_key(&self, table_data: &VrfKeyTableData) -> Result<(), VrfKeyStorageError> {
match &self.db {
DatabaseType::MsSql(db) => db.store_vrf_key(table_data).await,
}
}
}
#[async_trait]
impl VRFKeyStorage for VrfKeyDatabase {
async fn retrieve(&self) -> Result<Vec<u8>, VrfError> {
if let Some(cached_key) = &self.cached_vrf_key {
return Ok(cached_key.clone());
}
match &self.get_vrf_key().await {
Ok(table_data) => table_data
.to_vrf_key(&self.vrf_key_config)
.await
.map(|k| k.0)
.map_err(|err| {
VrfError::SigningKey(format!("Error decrypting signing key: {err}"))
}),
Err(VrfKeyRetrievalError::KeyNotFound) => {
// Make a new key
let (table_data, key) =
VrfKeyTableData::new(&self.vrf_key_config)
.await
.map_err(|err| {
VrfError::SigningKey(format!("Failed to create a new VRF key: {err}"))
})?;
// store it
self.store_vrf_key(&table_data).await.map_err(|err| {
VrfError::SigningKey(format!("Failed to store a new VRF key: {err}"))
})?;
// and return
Ok(key.0)
}
Err(err) => {
error!(%err, "Key retrieval error");
Err(VrfError::SigningKey("Key retrieval error".to_string()))
}
}
}
}
pub struct VrfKeyTableData {
pub root_key_hash: Vec<u8>,
pub root_key_type: RootKeyType,
pub root_key_type: VrfRootKeyType,
pub enc_sym_key: Option<Vec<u8>>,
pub sym_enc_vrf_key: Vec<u8>,
pub sym_enc_vrf_key_nonce: Vec<u8>,
}
pub enum RootKeyType {
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum VrfRootKeyType {
#[cfg(test)]
None = 0,
SymmetricKey = 1,
RsaKey = 2,
}
impl From<i16> for VrfRootKeyType {
fn from(value: i16) -> Self {
match value {
1 => VrfRootKeyType::SymmetricKey,
2 => VrfRootKeyType::RsaKey,
#[cfg(test)]
0 => VrfRootKeyType::None,
_ => panic!("Invalid VrfRootKeyType value: {}", value),
}
}
}
impl From<VrfRootKeyType> for i16 {
fn from(value: VrfRootKeyType) -> Self {
match value {
VrfRootKeyType::SymmetricKey => 1,
VrfRootKeyType::RsaKey => 2,
#[cfg(test)]
VrfRootKeyType::None => 0,
}
}
}
#[derive(Clone)]
pub struct VrfKey(pub Vec<u8>);
impl VrfKeyTableData {
pub async fn new(config: VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> {
pub async fn new(config: &VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyCreationError> {
info!("Generating new VRF key and table data");
// handle constant key case separately to avoid unnecessary key generation / parsing
#[cfg(test)]
@@ -68,8 +150,8 @@ impl VrfKeyTableData {
return Ok((
VrfKeyTableData {
root_key_hash: vec![],
root_key_type: RootKeyType::None,
root_key_hash: config.root_key_hash().expect("hard coded vrf key"),
root_key_type: VrfRootKeyType::None,
enc_sym_key: None,
sym_enc_vrf_key: vec![],
sym_enc_vrf_key_nonce: vec![],
@@ -78,17 +160,14 @@ impl VrfKeyTableData {
));
}
let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key } = &config {
let raw_key = bitwarden_encoding::B64::from_str(key).map_err(|err| {
error!(%err, "Invalid b64 encoding of symmetric key");
VrfKeyCreationError
})?;
let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key: _ } = &config {
let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?;
(
XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| {
XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| {
error!(%err, "Invalid symmetric key length");
VrfKeyCreationError
})?,
raw_key.as_bytes().to_vec(),
raw_key,
)
} else {
let key = XChaCha20Poly1305::generate_key(rand::thread_rng());
@@ -103,16 +182,22 @@ impl VrfKeyTableData {
VrfKeyCreationError
})?;
match config {
match &config {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => {
let root_key_hash = blake3::hash(&sym_key).as_bytes().to_vec();
let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?;
tracing::debug!(
root_key_hash_len = root_key_hash.len(),
sym_enc_vrf_key_len = sym_enc_vrf_key.len(),
nonce_len = nonce.len(),
"Prepared symmetric-key VRF table data"
);
Ok((
VrfKeyTableData {
root_key_hash,
root_key_type: RootKeyType::SymmetricKey,
root_key_type: VrfRootKeyType::SymmetricKey,
enc_sym_key: None,
sym_enc_vrf_key,
sym_enc_vrf_key_nonce: nonce.to_vec(),
@@ -122,21 +207,11 @@ impl VrfKeyTableData {
}
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
let rsa_private_key =
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| {
error!(%err, "Failed to decode RSA private key from PEM format");
VrfKeyCreationError
})?;
let root_key_hash = blake3::hash(
rsa_private_key
.to_pkcs1_der()
.map_err(|err| {
error!(%err, "Failed to encode RSA private key to DER format");
VrfKeyCreationError
})?
.as_bytes(),
)
.as_bytes()
.to_vec();
let root_key_hash = config.root_key_hash().map_err(|_| VrfKeyCreationError)?;
let rsa_public_key = rsa_private_key.to_public_key();
let enc_sym_key = rsa_public_key
.encrypt(&mut rand::thread_rng(), Pkcs1v15Encrypt, &sym_key)
@@ -148,7 +223,7 @@ impl VrfKeyTableData {
Ok((
VrfKeyTableData {
root_key_hash,
root_key_type: RootKeyType::RsaKey,
root_key_type: VrfRootKeyType::RsaKey,
enc_sym_key: Some(enc_sym_key),
sym_enc_vrf_key,
sym_enc_vrf_key_nonce: nonce.to_vec(),
@@ -159,7 +234,7 @@ impl VrfKeyTableData {
}
}
pub async fn to_vrf_key(&self, config: VrfKeyConfig) -> Result<VrfKey, VrfKeyCreationError> {
pub async fn to_vrf_key(&self, config: &VrfKeyConfig) -> Result<VrfKey, VrfKeyCreationError> {
info!("Decrypting VrfKeyTableData to obtain VRF key");
// handle constant key case separately to avoid unnecessary key generation / parsing
#[cfg(test)]
@@ -179,15 +254,12 @@ impl VrfKeyTableData {
return Err(VrfKeyCreationError);
}
let nonce = GenericArray::from_slice(self.sym_enc_vrf_key_nonce.as_ref());
let vrf_key = match config {
let vrf_key = match &config {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
VrfKeyConfig::B64EncodedSymmetricKey { key } => {
let raw_key = bitwarden_encoding::B64::from_str(&key).map_err(|err| {
error!(%err, "Invalid b64 encoding of symmetric key");
VrfKeyCreationError
})?;
let sym = XChaCha20Poly1305::new_from_slice(raw_key.as_bytes()).map_err(|err| {
VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => {
let raw_key = config.root_key_bytes().map_err(|_| VrfKeyCreationError)?;
let sym = XChaCha20Poly1305::new_from_slice(&raw_key).map_err(|err| {
error!(%err, "Invalid symmetric key length");
VrfKeyCreationError
})?;
@@ -202,7 +274,7 @@ impl VrfKeyTableData {
}
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
let rsa_private_key =
rsa::RsaPrivateKey::from_pkcs1_pem(&private_key).map_err(|err| {
rsa::RsaPrivateKey::from_pkcs1_pem(private_key).map_err(|err| {
error!(%err, "Failed to decode RSA private key from PEM format");
VrfKeyCreationError
})?;
@@ -303,8 +375,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_generation_from_symmetric_key() {
let config = create_test_symmetric_config();
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
assert_eq!(table_data.enc_sym_key, None);
assert_eq!(
@@ -322,8 +394,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
pub async fn test_generation_from_rsa_key() {
let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
let config = create_test_rsa_config();
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
assert_eq!(
table_data.root_key_hash,
[
@@ -343,8 +415,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_generation_from_constant_key() {
let config = super::VrfKeyConfig::ConstantVrfKey;
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(&config).await.unwrap();
assert_eq!(table_data.root_key_hash, vec![]);
assert_eq!(table_data.enc_sym_key, None);
@@ -360,19 +432,19 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
key: "not!valid@base64#".to_string(),
};
let result = super::VrfKeyTableData::new(config.clone()).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
#[tokio::test]
pub async fn test_invalid_base64_during_retrieval() {
let config_valid = create_test_symmetric_config();
let (table_data, _) = super::VrfKeyTableData::new(config_valid).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config_valid).await.unwrap();
let config_invalid = super::VrfKeyConfig::B64EncodedSymmetricKey {
key: "not!valid@base64#".to_string(),
};
let result = table_data.to_vrf_key(config_invalid).await;
let result = table_data.to_vrf_key(&config_invalid).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
@@ -381,7 +453,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
let short_key = bitwarden_encoding::B64::from(vec![0u8; 16]).to_string();
let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: short_key };
let result = super::VrfKeyTableData::new(config).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
@@ -393,7 +465,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
private_key: malformed_pem.to_string(),
};
let result = super::VrfKeyTableData::new(config).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
@@ -406,62 +478,62 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
private_key: missing_headers.to_string(),
};
let result = super::VrfKeyTableData::new(config).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
#[tokio::test]
pub async fn test_wrong_symmetric_key_decryption() {
let config1 = create_test_symmetric_config();
let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap();
let config2 = super::VrfKeyConfig::B64EncodedSymmetricKey {
key: generate_random_symmetric_key_b64(),
};
let result = table_data.to_vrf_key(config2).await;
let result = table_data.to_vrf_key(&config2).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
#[tokio::test]
pub async fn test_wrong_rsa_key_decryption() {
let config1 = create_test_rsa_config();
let (table_data, _) = super::VrfKeyTableData::new(config1).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config1).await.unwrap();
let config2 = super::VrfKeyConfig::PEMEncodedRSAKey {
private_key: generate_random_rsa_key_pem(),
};
let result = table_data.to_vrf_key(config2).await;
let result = table_data.to_vrf_key(&config2).await;
assert!(result.is_err());
}
#[tokio::test]
pub async fn test_rsa_missing_enc_sym_key() {
let config = create_test_rsa_config();
let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
table_data.enc_sym_key = None;
let result = table_data.to_vrf_key(config).await;
let result = table_data.to_vrf_key(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
#[tokio::test]
pub async fn test_wrong_nonce() {
let config = create_test_symmetric_config();
let (mut table_data, _) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let (mut table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
table_data.sym_enc_vrf_key_nonce.truncate(10);
let result = table_data.to_vrf_key(config).await;
let result = table_data.to_vrf_key(&config).await;
assert!(result.is_err());
}
#[tokio::test]
pub async fn test_nonce_size_validation() {
let config = create_test_symmetric_config();
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
assert_eq!(table_data.sym_enc_vrf_key_nonce.len(), 24);
}
@@ -470,7 +542,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
pub async fn test_empty_symmetric_key() {
let config = super::VrfKeyConfig::B64EncodedSymmetricKey { key: String::new() };
let result = super::VrfKeyTableData::new(config).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(result.is_err());
}
@@ -480,7 +552,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
private_key: String::new(),
};
let result = super::VrfKeyTableData::new(config).await;
let result = super::VrfKeyTableData::new(&config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
@@ -489,8 +561,8 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
let config1 = create_test_symmetric_config();
let config2 = create_test_symmetric_config();
let (table_data1, vrf_key1) = super::VrfKeyTableData::new(config1).await.unwrap();
let (table_data2, vrf_key2) = super::VrfKeyTableData::new(config2).await.unwrap();
let (table_data1, vrf_key1) = super::VrfKeyTableData::new(&config1).await.unwrap();
let (table_data2, vrf_key2) = super::VrfKeyTableData::new(&config2).await.unwrap();
assert_ne!(vrf_key1.0, vrf_key2.0);
assert_ne!(
@@ -502,7 +574,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_symmetric_key_not_persisted() {
let config = create_test_symmetric_config();
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
let symmetric_key_bytes =
bitwarden_encoding::B64::from_str(TEST_SYMMETRIC_KEY_B64).unwrap();
@@ -516,7 +588,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_rsa_private_key_not_persisted() {
let config = create_test_rsa_config();
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
let rsa_der = rsa_key.to_pkcs1_der().unwrap();
@@ -532,7 +604,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_vrf_key_is_encrypted_at_rest() {
let config = create_test_symmetric_config();
let (table_data, vrf_key) = super::VrfKeyTableData::new(config).await.unwrap();
let (table_data, vrf_key) = super::VrfKeyTableData::new(&config).await.unwrap();
assert!(!table_data
.sym_enc_vrf_key
@@ -543,7 +615,7 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_symmetric_key_encryption_in_rsa_mode() {
let config = create_test_rsa_config();
let (table_data, _) = super::VrfKeyTableData::new(config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&config).await.unwrap();
let enc_sym_key = table_data.enc_sym_key.as_ref().unwrap();
let rsa_key = rsa::RsaPrivateKey::from_pkcs1_pem(TEST_RSA_PRIVATE_KEY).unwrap();
@@ -556,22 +628,22 @@ k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
#[tokio::test]
pub async fn test_cannot_decrypt_symmetric_with_rsa_config() {
let sym_config = create_test_symmetric_config();
let (table_data, _) = super::VrfKeyTableData::new(sym_config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&sym_config).await.unwrap();
let rsa_config = create_test_rsa_config();
let result = table_data.to_vrf_key(rsa_config).await;
let result = table_data.to_vrf_key(&rsa_config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
#[tokio::test]
pub async fn test_cannot_decrypt_rsa_with_symmetric_config() {
let rsa_config = create_test_rsa_config();
let (table_data, _) = super::VrfKeyTableData::new(rsa_config).await.unwrap();
let (table_data, _) = super::VrfKeyTableData::new(&rsa_config).await.unwrap();
let sym_config = create_test_symmetric_config();
let result = table_data.to_vrf_key(sym_config).await;
let result = table_data.to_vrf_key(&sym_config).await;
assert!(matches!(result, Err(super::VrfKeyCreationError)));
}
}
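
For readers unfamiliar with the envelope pattern used by VrfKeyTableData, here is a standalone round trip with the same chacha20poly1305 primitives (a toy illustration, not this crate's code): the root symmetric key encrypts the VRF signing key, producing the sym_enc_vrf_key ciphertext and the 24-byte sym_enc_vrf_key_nonce; in the RSA mode, a freshly generated symmetric key takes that role and is itself RSA-encrypted into enc_sym_key.

use chacha20poly1305::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    XChaCha20Poly1305,
};

fn main() {
    // Stand-ins: a random root symmetric key and a fake 32-byte VRF signing key.
    let root_key = XChaCha20Poly1305::generate_key(&mut OsRng);
    let cipher = XChaCha20Poly1305::new(&root_key);
    let vrf_key = b"example 32-byte vrf signing key!";

    // Encrypt: this pair is what ends up in sym_enc_vrf_key / sym_enc_vrf_key_nonce.
    let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng); // 24 bytes (XNonce)
    let ciphertext = cipher.encrypt(&nonce, vrf_key.as_ref()).expect("encrypt");

    // Decrypt: what to_vrf_key() does after loading the row back.
    let recovered = cipher.decrypt(&nonce, ciphertext.as_ref()).expect("decrypt");
    assert_eq!(recovered, vrf_key.to_vec());
}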


@@ -1,8 +1,7 @@
use akd::ecvrf::HardCodedAkdVRF;
use akd::storage::StorageManager;
use akd::Directory;
use akd_storage::db_config::DbConfig;
use akd_storage::DatabaseType;
use akd_storage::akd_storage_config::AkdStorageConfig;
use akd_storage::db_config::{DatabaseType, DbConfig};
use akd_storage::AkdDatabase;
use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use commands::Command;
@@ -159,33 +158,37 @@ async fn main() -> Result<()> {
// Create database connection
info!("Connecting to MS SQL database");
let config = DbConfig::MsSql {
connection_string,
pool_size: args.pool_size,
let config = AkdStorageConfig {
db_config: DbConfig::MsSql {
connection_string: connection_string.clone(),
pool_size: args.pool_size,
},
cache_item_lifetime_ms: 30000,
cache_limit_bytes: None,
cache_clean_ms: 15000,
vrf_key_config: akd_storage::vrf_key_config::VrfKeyConfig::B64EncodedSymmetricKey {
key: "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=".to_string(),
},
};
let db = config
.connect()
let (storage_manager, state) = config
.initialize_storage()
.await
.context("Failed to connect to database")?;
.context("Failed to initialize storage")?;
// Handle pre-processing modes
if let Some(()) = pre_process_mode(&args, &db).await? {
if let Some(()) = pre_process_mode(&args, &state.db()).await? {
return Ok(());
}
let storage_manager = StorageManager::new(db.clone(), None, None, None);
let vrf = HardCodedAkdVRF {};
let mut directory = Directory::<TC, _, _>::new(storage_manager.clone(), vrf)
let mut directory = Directory::<TC, _, _>::new(storage_manager, state.vrf_key_database())
.await
.context("Failed to create AKD directory")?;
let (tx, mut rx) = channel(2);
tokio::spawn(async move {
directory_host::init_host::<TC, _, HardCodedAkdVRF>(&mut rx, &mut directory).await
});
tokio::spawn(async move { directory_host::init_host(&mut rx, &mut directory).await });
process_mode(&args, &tx, &db).await?;
process_mode(&args, &tx, &state).await?;
Ok(())
}
@@ -221,7 +224,7 @@ async fn pre_process_mode(args: &CliArgs, db: &DatabaseType) -> Result<Option<()
async fn process_mode(
args: &CliArgs,
tx: &Sender<directory_host::Rpc>,
db: &DatabaseType,
db: &AkdDatabase,
) -> Result<()> {
if let Some(mode) = &args.mode {
match mode {
@@ -260,7 +263,7 @@ async fn process_mode(
Ok(())
}
async fn bench_db_insert(num_users: u64, db: &DatabaseType) -> Result<()> {
async fn bench_db_insert(num_users: u64, db: &AkdDatabase) -> Result<()> {
use owo_colors::OwoColorize;
println!("{}", "======= Benchmark operation requested =======".cyan());
@@ -488,7 +491,7 @@ async fn bench_lookup(
async fn repl_loop(
_args: &CliArgs,
tx: &Sender<directory_host::Rpc>,
db: &DatabaseType,
db: &AkdDatabase,
) -> Result<()> {
loop {
println!("Please enter a command");
@@ -498,7 +501,7 @@ async fn repl_loop(
let mut line = String::new();
stdin().read_line(&mut line)?;
match (db, Command::parse(&mut line)) {
match (db.db(), Command::parse(&mut line)) {
(_, Command::Unknown(other)) => {
println!("Input '{other}' is not supported, enter 'help' for the help menu")
}


@@ -1,43 +0,0 @@
use thiserror::Error;
mod storage_manager_config;
pub use akd_storage::vrf_key_config::VrfKeyConfig;
pub use storage_manager_config::*;
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("Failed to connect to database")]
DatabaseConnection(#[source] akd::errors::StorageError),
#[error("Configuration value 'cache_item_lifetime_ms' is invalid: value {value} exceeds maximum allowed ({max})")]
CacheLifetimeOutOfRange {
value: usize,
max: u64,
#[source]
source: std::num::TryFromIntError,
},
#[error("Configuration value 'cache_clean_frequency_ms' is invalid: value {value} exceeds maximum allowed ({max})")]
CacheCleanFrequencyOutOfRange {
value: usize,
max: u64,
#[source]
source: std::num::TryFromIntError,
},
#[error("{0}")]
Custom(String),
#[error("Invalid hex string for VRF key material")]
InvalidVrfKeyMaterialHex(#[source] hex::FromHexError),
#[error("VRF key material must be exactly 32 bytes, got {actual} bytes")]
VrfKeyMaterialInvalidLength { actual: usize },
}
impl ConfigError {
pub fn new(message: impl Into<String>) -> Self {
Self::Custom(message.into())
}
}


@@ -1,59 +0,0 @@
use std::time::Duration;
use crate::config::ConfigError;
use akd::storage::StorageManager;
use akd_storage::{db_config::DbConfig, DatabaseType};
use serde::{Deserialize, Serialize};
/// items live for 30s by default
pub const DEFAULT_ITEM_LIFETIME_MS: usize = 30_000;
/// clean the cache every 15s by default
pub const DEFAULT_CACHE_CLEAN_FREQUENCY_MS: usize = 15_000;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StorageManagerConfig {
pub db_config: DbConfig,
pub cache_limit_bytes: Option<usize>,
#[serde(default = "default_cache_item_lifetime_ms")]
pub cache_item_lifetime_ms: usize,
#[serde(default = "default_cache_clean_frequency_ms")]
pub cache_clean_frequency_ms: usize,
}
fn default_cache_item_lifetime_ms() -> usize {
DEFAULT_ITEM_LIFETIME_MS
}
fn default_cache_clean_frequency_ms() -> usize {
DEFAULT_CACHE_CLEAN_FREQUENCY_MS
}
impl StorageManagerConfig {
pub async fn create(&self) -> Result<StorageManager<DatabaseType>, ConfigError> {
Ok(StorageManager::new(
self.db_config
.connect()
.await
.map_err(ConfigError::DatabaseConnection)?,
Some(Duration::from_millis(
self.cache_item_lifetime_ms.try_into().map_err(|source| {
ConfigError::CacheLifetimeOutOfRange {
value: self.cache_item_lifetime_ms,
max: u64::MAX,
source,
}
})?,
)),
self.cache_limit_bytes,
Some(Duration::from_millis(
self.cache_clean_frequency_ms.try_into().map_err(|source| {
ConfigError::CacheCleanFrequencyOutOfRange {
value: self.cache_clean_frequency_ms,
max: u64::MAX,
source,
}
})?,
)),
))
}
}

View File

@@ -1 +1 @@
pub mod config;


@@ -1,7 +1,6 @@
use akd::{directory::Directory, storage::StorageManager};
use akd_storage::DatabaseType;
use bitwarden_akd_configuration::BitwardenV1Configuration;
use common::VrfStorageType;
use tracing::instrument;
struct AppState {


@@ -1,7 +1,6 @@
use akd::{directory::ReadOnlyDirectory, storage::StorageManager};
use akd_storage::DatabaseType;
use bitwarden_akd_configuration::BitwardenV1Configuration;
use common::VrfStorageType;
use tracing::instrument;
struct AppState {
@@ -10,7 +9,7 @@ struct AppState {
}
#[instrument(skip_all, name = "reader_start")]
pub async fn start(db: DatabaseType, vrf: VrfStorageType) {
pub async fn start(db: DatabaseType) {
let storage_manager = StorageManager::new_no_cache(db);
let _app = AppState {
_directory: ReadOnlyDirectory::new(storage_manager, vrf)