1
0
mirror of https://github.com/bitwarden/server synced 2026-02-13 15:04:03 +00:00

Vrf keys are created by the application and protected by external means

This commit is contained in:
Matt Gibson
2025-12-15 09:04:59 -08:00
parent 6323175da9
commit bc82b338a1
29 changed files with 1180 additions and 53 deletions

232
akd/Cargo.lock generated
View File

@@ -17,6 +17,16 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
@@ -79,9 +89,17 @@ version = "0.1.0"
dependencies = [
"akd",
"async-trait",
"base64 0.22.1",
"bitwarden-encoding",
"blake3",
"chacha20poly1305",
"ed25519-dalek",
"ms_database",
"rand 0.8.5",
"rsa",
"serde",
"thiserror 2.0.17",
"tokio",
"tracing",
]
@@ -380,6 +398,12 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.8.0"
@@ -423,6 +447,17 @@ dependencies = [
"uuid",
]
[[package]]
name = "bitwarden-encoding"
version = "0.1.0"
dependencies = [
"data-encoding",
"data-encoding-macro",
"serde",
"serde_json",
"thiserror 2.0.17",
]
[[package]]
name = "blake3"
version = "1.8.2"
@@ -492,6 +527,30 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9"
[[package]]
name = "chacha20"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "chacha20poly1305"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
dependencies = [
"aead",
"chacha20",
"cipher",
"poly1305",
"zeroize",
]
[[package]]
name = "chrono"
version = "0.4.42"
@@ -501,6 +560,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
"zeroize",
]
[[package]]
name = "clap"
version = "4.5.50"
@@ -681,6 +751,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
@@ -724,6 +795,32 @@ dependencies = [
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "data-encoding-macro"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47ce6c96ea0102f01122a185683611bd5ac8d99e62bc59dd12e6bda344ee673d"
dependencies = [
"data-encoding",
"data-encoding-macro-internal",
]
[[package]]
name = "data-encoding-macro-internal"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d162beedaa69905488a8da94f5ac3edb4dd4788b732fadb7bd120b2625c1976"
dependencies = [
"data-encoding",
"syn 2.0.106",
]
[[package]]
name = "der"
version = "0.7.10"
@@ -731,6 +828,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"const-oid",
"pem-rfc7468",
"zeroize",
]
@@ -741,6 +839,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"const-oid",
"crypto-common",
]
@@ -782,6 +881,7 @@ checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9"
dependencies = [
"curve25519-dalek",
"ed25519",
"rand_core 0.6.4",
"serde",
"sha2",
"signature",
@@ -1237,6 +1337,15 @@ dependencies = [
"hashbrown 0.16.0",
]
[[package]]
name = "inout"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
[[package]]
name = "io-uring"
version = "0.7.10"
@@ -1295,6 +1404,9 @@ name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
dependencies = [
"spin",
]
[[package]]
name = "libc"
@@ -1302,6 +1414,12 @@ version = "0.2.176"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
@@ -1431,6 +1549,42 @@ dependencies = [
"windows-sys 0.61.1",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7"
dependencies = [
"lazy_static",
"libm",
"num-integer",
"num-iter",
"num-traits",
"rand 0.8.5",
"smallvec",
"zeroize",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -1438,6 +1592,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@@ -1461,6 +1616,12 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl"
version = "0.10.75"
@@ -1556,6 +1717,15 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
dependencies = [
"base64ct",
]
[[package]]
name = "percent-encoding"
version = "2.3.2"
@@ -1628,6 +1798,17 @@ dependencies = [
"futures-io",
]
[[package]]
name = "pkcs1"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f"
dependencies = [
"der",
"pkcs8",
"spki",
]
[[package]]
name = "pkcs8"
version = "0.10.2"
@@ -1658,6 +1839,17 @@ dependencies = [
"windows-sys 0.61.1",
]
[[package]]
name = "poly1305"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
dependencies = [
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "potential_utf"
version = "0.1.4"
@@ -1915,12 +2107,32 @@ version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94"
dependencies = [
"base64",
"base64 0.21.7",
"bitflags 2.9.4",
"serde",
"serde_derive",
]
[[package]]
name = "rsa"
version = "0.9.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40a0376c50d0358279d9d643e4bf7b7be212f1f4ff1da9070a7b54d22ef75c88"
dependencies = [
"const-oid",
"digest",
"num-bigint-dig",
"num-integer",
"num-traits",
"pkcs1",
"pkcs8",
"rand_core 0.6.4",
"signature",
"spki",
"subtle",
"zeroize",
]
[[package]]
name = "rust-ini"
version = "0.21.3"
@@ -2002,7 +2214,7 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c"
dependencies = [
"base64",
"base64 0.21.7",
]
[[package]]
@@ -2212,6 +2424,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.3"
@@ -2601,6 +2819,16 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"

View File

@@ -18,9 +18,11 @@ akd = "0.11.0"
async-trait = "0.1.89"
akd_storage = { path = "crates/akd_storage" }
bitwarden-akd-configuration = { path = "crates/bitwarden-akd-configuration" }
blake3 = "1.8.2"
common = { path = "crates/common" }
config = "0.15.18"
serde = { version = "1.0.228", features = ["derive"] }
tokio = { version = "1.47.1", features = ["full"] }
tracing = { version = "0.1.41" }
tracing-subscriber = {version = "0.3.19" }
tracing-subscriber = { version = "0.3.19" }
thiserror = "2.0.17"

View File

@@ -6,6 +6,7 @@ use akd_storage::db_config::DbConfig;
use common::VrfStorageType;
#[tokio::main]
#[allow(unreachable_code)]
async fn main() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)

View File

@@ -8,11 +8,21 @@ keywords.workspace = true
[dependencies]
akd = "0.11.0"
async-trait = { workspace = true}
async-trait.workspace = true
base64 = "0.22.1"
bitwarden-encoding = { path = "../bitwarden-encoding" }
blake3.workspace = true
chacha20poly1305 = { version = "0.10.1" }
ed25519-dalek = { version = ">=2.1.1, <=2.2.0", features = ["rand_core"] }
ms_database = { path = "../ms_database" }
rand = ">=0.8.5, <0.9"
rsa = { version = ">=0.9.2, <0.10" }
serde = { workspace = true }
thiserror = "2.0.17"
thiserror.workspace = true
tracing.workspace = true
[dev-dependencies]
tokio.workspace = true
[lints]
workspace = true

View File

@@ -0,0 +1 @@
DROP TABLE IF EXISTS dbo.vrf_key;

View File

@@ -0,0 +1,11 @@
IF OBJECT_ID('dbo.vrf_key', 'U') IS NULL
BEGIN
CREATE TABLE dbo.vrf_key (
root_key_hash VARBINARY(32) NOT NULL,
root_key_type INT NOT NULL,
enc_sym_key VARBINARY(32) NULL,
enc_sym_key_nonce VARBINARY(24) NULL,
sym_enc_vrf_key VARBINARY(32) NOT NULL,
PRIMARY KEY (root_key_hash, root_key_type)
);
END

View File

@@ -4,16 +4,16 @@ use crate::db_config::DbConfig;
#[derive(Debug, Clone, Deserialize)]
pub struct AkdStorageConfig {
db_config: DbConfig,
_db_config: DbConfig,
/// Controls how long items stay in cache before being evicted (in milliseconds). Defaults to 30 seconds.
#[serde(default = "default_cache_item_lifetime_ms")]
cache_item_lifetime_ms: usize,
_cache_item_lifetime_ms: usize,
/// Controls the maximum size of the cache in bytes. Defaults to no limit.
#[serde(default)]
cache_limit_bytes: Option<usize>,
_cache_limit_bytes: Option<usize>,
/// Controls how often the cache is cleaned (in milliseconds). Defaults to 15 seconds.
#[serde(default = "default_cache_clean_ms")]
cache_clean_ms: usize,
_cache_clean_ms: usize,
}
fn default_cache_item_lifetime_ms() -> usize {

View File

@@ -15,6 +15,8 @@ use crate::ms_sql::MsSql;
pub mod akd_storage_config;
pub mod db_config;
pub mod ms_sql;
mod vrf_key_config;
pub mod vrf_key_database;
/// Enum to represent different database types supported by the storage layer.
/// Each variant is cheap to clone for reuse across threads.

View File

@@ -98,10 +98,10 @@ impl SqlParams {
.collect()
}
pub fn values(&self) -> Vec<&(dyn ToSql)> {
pub fn values(&self) -> Vec<&dyn ToSql> {
self.params
.iter()
.map(|b| b.data.as_ref() as &(dyn ToSql))
.map(|b| b.data.as_ref() as &dyn ToSql)
.collect()
}
}

View File

@@ -110,8 +110,8 @@ pub(crate) trait AkdStorableForMsSql {
fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError>;
fn get_batch_temp_table_rows<St: Storable>(
key: &[St::StorageKey],
) -> Result<Vec<TokenRow>, StorageError>;
key: &'_ [St::StorageKey],
) -> Result<Vec<TokenRow<'_>>, StorageError>;
fn get_batch_statement<St: Storable>() -> String;
@@ -119,7 +119,7 @@ pub(crate) trait AkdStorableForMsSql {
where
Self: Sized;
fn into_row(&self) -> Result<TokenRow, StorageError>;
fn into_row(&'_ self) -> Result<TokenRow<'_>, StorageError>;
}
impl AkdStorableForMsSql for DbRecord {
@@ -479,8 +479,8 @@ impl AkdStorableForMsSql for DbRecord {
}
fn get_batch_temp_table_rows<St: Storable>(
key: &[St::StorageKey],
) -> Result<Vec<TokenRow>, StorageError> {
key: &'_ [St::StorageKey],
) -> Result<Vec<TokenRow<'_>>, StorageError> {
match St::data_type() {
StorageType::Azks => Err(StorageError::Other(
"Batch temp table rows not supported for Azks".to_string(),
@@ -681,7 +681,7 @@ impl AkdStorableForMsSql for DbRecord {
}
}
fn into_row(&self) -> Result<TokenRow, StorageError> {
fn into_row(&'_ self) -> Result<TokenRow<'_>, StorageError> {
match &self {
DbRecord::Azks(azks) => {
let row = (

View File

@@ -1,3 +1,4 @@
pub(crate) mod akd_storable_for_ms_sql;
pub(crate) mod temp_table;
pub(crate) mod values;
pub(crate) mod vrf_key;

View File

@@ -0,0 +1,6 @@
use crate::vrf_key_database::VrfKeyTableData;
#[allow(unused)]
pub async fn get_vrf_key(root_key: &[u8]) -> VrfKeyTableData {
todo!()
}

View File

@@ -0,0 +1,38 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum VrfKeyConfig {
/// **WARNING**: Do not use this in production systems. This is only for testing and debugging.
/// This is a version of VRFKeyStorage for testing purposes, which uses the example from the VRF crate.
///
/// const KEY_MATERIAL: &str = "c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721";
#[cfg(test)]
ConstantVrfKey,
/// The root key is a valid and random chacha20poly1305 symmetric key directly. The provided string will be decoded
/// from base64 to produce a key.
///
/// For VRF Key generation, a random VRF private key will be sampled, encrypted with this symmetric
/// key, and the resulting encrypted VRF key will be stored. The symmetric key will not be persisted.
///
/// For VRF Key retrieval, the symmetric key will be hashed to derive a root key identifier. This will be used
/// to lookup an associated VRF key. If none is found, the application will error if a VRF key already exists.
/// Otherwise it goes on to generate a new VRF key. If a VRF key is found, it will be decrypted using this
/// symmetric key
///
/// Losing this key is equivalent to losing your directory's VRF key.
B64EncodedSymmetricKey { key: String },
/// The root key is an asymmetric RSA key. The provided string will be decoded from pkcs1 PEM to produce a private RSA key.
///
/// For VRF key generation, a random VRF private key will be sampled, a random symmetric key will be sampled,
/// the VRF key will be encrypted with the symmetric key, and the symmetric key will be encrypted with the RSA public key.
/// The resulting encrypted VRF key and encrypted symmetric key will be stored. The RSA private key will not be persisted.
///
/// For VRF key retrieval, the RSA private key will be hashed to derive a root key identifier. This will be used
/// to lookup an associated VRF key. If None is found, the application will error if a VRF key already exists.
/// Otherwise it goes on to generate a new VRF key. If a VRF key is found, the symmetric key will be decrypted using
/// the RSA private key, and then the VRF key will be decrypted using the symmetric key.
///
/// Losing this key is equivalent to losing your directory's VRF key.
PEMEncodedRSAKey { private_key: String },
}

View File

@@ -0,0 +1,259 @@
use std::str::FromStr;
use bitwarden_encoding::NotB64EncodedError;
use chacha20poly1305::{
aead::{generic_array::GenericArray, Aead},
AeadCore, KeyInit, XChaCha20Poly1305,
};
use rsa::{
pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey},
signature::digest::crypto_common,
Pkcs1v15Encrypt,
};
use thiserror::Error;
use crate::vrf_key_config::VrfKeyConfig;
/// Represents a storage-layer error
#[derive(Debug, Error)]
pub enum VrfKeyStorageError {
/// No VRF key exists for the given root key
#[error("VRF key not found for the specified root key")]
KeyNotFound,
/// A VRF key already exists, but for a different root key
#[error("A VRF key already exists for a different root key")]
KeyExistsForDifferentRootKey,
/// A transaction error
#[error("Database transaction failed: {0}")]
Transaction(String),
/// Some kind of storage connection error occurred
#[error("Storage connection error: {0}")]
Connection(String),
/// Base64 decoding error
#[error("Failed to decode base64 data: {0}")]
B64DecodingError(#[from] NotB64EncodedError),
/// ChaCha20Poly1305 length error
#[error("Invalid key length error: {0}")]
KeyLengthError(#[from] crypto_common::InvalidLength),
/// Symmetric encryption/decryption error
#[error("Symmetric encryption/decryption error")]
SymmetricEncryptionError,
/// RSA error
#[error("RSA key encoding error: {0}")]
RsaKeyEncodingError(#[from] rsa::pkcs1::Error),
#[error("RSA encryption/decryption error: {0}")]
RsaEncryptionError(#[from] rsa::Error),
/// Some other storage-layer error occurred
#[error("Storage error: {0}")]
Other(&'static str),
}
#[allow(unused)]
trait VrfKeyTable {
async fn get_vrf_key(root_key: &[u8]) -> Result<VrfKeyTableData, VrfKeyStorageError>;
async fn store_vrf_key(root_key: &[u8]) -> Result<(), VrfKeyStorageError>;
}
pub struct VrfKeyTableData {
pub root_key_hash: Vec<u8>,
pub root_key_type: RootKeyType,
pub enc_sym_key: Option<Vec<u8>>,
pub sym_enc_vrf_key: Vec<u8>,
pub sym_enc_vrf_key_nonce: Vec<u8>,
}
pub enum RootKeyType {
#[cfg(test)]
None = 0,
SymmetricKey = 1,
RsaKey = 2,
}
pub struct VrfKey(pub Vec<u8>);
impl VrfKeyTableData {
pub async fn new(config: VrfKeyConfig) -> Result<(Self, VrfKey), VrfKeyStorageError> {
// handle constant key case separately to avoid unnecessary key generation / parsing
#[cfg(test)]
if let VrfKeyConfig::ConstantVrfKey = config {
use akd::ecvrf::{HardCodedAkdVRF, VRFKeyStorage};
return Ok((
VrfKeyTableData {
root_key_hash: vec![],
root_key_type: RootKeyType::None,
enc_sym_key: None,
sym_enc_vrf_key: vec![],
sym_enc_vrf_key_nonce: vec![],
},
VrfKey((HardCodedAkdVRF {}).retrieve().await.unwrap_or_default()),
));
}
let (sym, sym_key) = if let VrfKeyConfig::B64EncodedSymmetricKey { key } = &config {
let raw_key = bitwarden_encoding::B64::from_str(key)?;
(
XChaCha20Poly1305::new_from_slice(raw_key.as_bytes())?,
raw_key.as_bytes().to_vec(),
)
} else {
let key = XChaCha20Poly1305::generate_key(rand::thread_rng());
(XChaCha20Poly1305::new(&key), key.to_vec())
};
let vrf_key = ed25519_dalek::SigningKey::generate(&mut rand::thread_rng())
.to_bytes()
.to_vec();
let nonce = XChaCha20Poly1305::generate_nonce(&mut rand::thread_rng());
let sym_enc_vrf_key = sym
.encrypt(&nonce, &vrf_key[..])
.map_err(|_| VrfKeyStorageError::SymmetricEncryptionError)?;
match config {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
VrfKeyConfig::B64EncodedSymmetricKey { key: _ } => {
let root_key_hash = blake3::hash(&sym_key).as_bytes().to_vec();
Ok((
VrfKeyTableData {
root_key_hash,
root_key_type: RootKeyType::SymmetricKey,
enc_sym_key: None,
sym_enc_vrf_key,
sym_enc_vrf_key_nonce: nonce.to_vec(),
},
VrfKey(vrf_key),
))
}
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(&private_key)?;
let root_key_hash = blake3::hash(rsa_private_key.to_pkcs1_der()?.as_bytes())
.as_bytes()
.to_vec();
let rsa_public_key = rsa_private_key.to_public_key();
let enc_sym_key =
rsa_public_key.encrypt(&mut rand::thread_rng(), Pkcs1v15Encrypt, &sym_key)?;
Ok((
VrfKeyTableData {
root_key_hash,
root_key_type: RootKeyType::RsaKey,
enc_sym_key: Some(enc_sym_key),
sym_enc_vrf_key,
sym_enc_vrf_key_nonce: nonce.to_vec(),
},
VrfKey(vrf_key),
))
}
}
}
pub async fn to_vrf_key(&self, config: VrfKeyConfig) -> Result<VrfKey, VrfKeyStorageError> {
// handle constant key case separately to avoid unnecessary key generation / parsing
#[cfg(test)]
if let VrfKeyConfig::ConstantVrfKey = config {
use akd::ecvrf::{HardCodedAkdVRF, VRFKeyStorage};
return Ok(VrfKey(
(HardCodedAkdVRF {}).retrieve().await.unwrap_or_default(),
));
}
let nonce = GenericArray::from_slice(self.sym_enc_vrf_key_nonce.as_ref());
let vrf_key = match config {
#[cfg(test)]
VrfKeyConfig::ConstantVrfKey => unreachable!(), // handled above
VrfKeyConfig::B64EncodedSymmetricKey { key } => {
let raw_key = bitwarden_encoding::B64::from_str(&key)?;
let sym = XChaCha20Poly1305::new_from_slice(raw_key.as_bytes())?;
let vrf_key = sym
.decrypt(nonce, &self.sym_enc_vrf_key[..])
.map_err(|_| VrfKeyStorageError::SymmetricEncryptionError)?;
vrf_key
}
VrfKeyConfig::PEMEncodedRSAKey { private_key } => {
let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(&private_key)?;
let enc_sym_key = self.enc_sym_key.as_ref().ok_or(VrfKeyStorageError::Other(
"missing encrypted symmetric key for RSA root key",
))?;
let sym_key = rsa_private_key.decrypt(Pkcs1v15Encrypt, enc_sym_key)?;
let sym = XChaCha20Poly1305::new_from_slice(&sym_key)?;
let vrf_key = sym
.decrypt(&nonce, &self.sym_enc_vrf_key[..])
.map_err(|_| VrfKeyStorageError::SymmetricEncryptionError)?;
vrf_key
}
};
Ok(VrfKey(vrf_key))
}
}
#[cfg(test)]
mod tests {
use rsa::{pkcs1::DecodeRsaPrivateKey, Pkcs1v15Encrypt};
#[tokio::test]
pub async fn test_generation_from_symmetric_key() {
let symmetric_key_b64 = "4AD95tg8tfveioyS/E2jAQw06FDTUCu+VSEZxa41wuM=";
let config = super::VrfKeyConfig::B64EncodedSymmetricKey {
key: symmetric_key_b64.to_string(),
};
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
assert_eq!(table_data.enc_sym_key, None);
assert_eq!(
table_data.root_key_hash,
[
130, 153, 58, 122, 202, 166, 92, 56, 249, 28, 57, 171, 206, 187, 12, 81, 44, 166,
61, 41, 188, 84, 20, 43, 108, 211, 146, 152, 243, 155, 49, 66
]
);
assert_eq!(vrf_key.0, retrieved_vrf_key.0);
}
#[tokio::test]
pub async fn test_generation_from_rsa_key() {
let rsa_private_key_pem = r"-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCaPQBvavQC8o/A0map70QTqGz6ETMURzHaWIEjlS89ytjj+8Zs
K9L1HCy9SOShFcSYrGb47CdMhMKHa/1YRUVA653uO4rqlO+wPhOZEzljvp9zXvDz
ybLjF2aGZg61w1rC25l36M0NUx8HN+Ws+14mcVzllUiXbk9PMXhWFKoj2wIDAQAB
AoGAU61Sph/NQCgea0r6nakMMuoGLWjVYGP7nOy1KvvNxGVfY9h9XsQr0AS4FP0N
5IKtxPKLbvKXo4DHFLc2nAQAvI8kUPZM40jyVk2yUr2k48PMkssdQKXJ/qRi6PeI
LLLSh7IHDYWdVL7pHA1a7ghH+DIATkA83/++QON1btyKSNECQQDMkKZhqjP2OAbW
5xYrmJp3Q2TlXRjwuOdZLD8uXHl15vAxGokkawxkVlW5vI99tdnqS6Kp5U0THP6H
jc+Hii85AkEAwQTxM1Nr3McluiS5kXs8FjdlgUJ+zRAZWOHQqEazQXDlXFVODHFO
+Rh2sX9eqFUc07sJyjV1xLoN5Fe8DjUXswJABy91iKyv0pA5PUc0sidUFahaXOwe
OiZkie9R8NDyuz93ZGIoOw0/jC60KCgFakb+9ondltYlFOzJy/0hMwOZkQJAc+rB
5+8LcfVvZNC1WPdHaJgwL2Z9vC0U69oBc22yLXTdaYwZaUOLB/F3JrW1ZSZoP4eu
I2/joBeUTDOcTnP4HQJBAICmnHCopJ1sSfQG3fMDobOStJBvxQwLkGeRGzI2XsMw
k7UXX8Wh7AgrK4A/MuZXJL30Cd/dgtlHzJWtlQevTII=
-----END RSA PRIVATE KEY-----";
let rsa_private_key = rsa::RsaPrivateKey::from_pkcs1_pem(rsa_private_key_pem).unwrap();
let config = super::VrfKeyConfig::PEMEncodedRSAKey {
private_key: rsa_private_key_pem.to_string(),
};
let (table_data, vrf_key) = super::VrfKeyTableData::new(config.clone()).await.unwrap();
let retrieved_vrf_key = table_data.to_vrf_key(config).await.unwrap();
assert_eq!(
table_data.root_key_hash,
[
124, 52, 131, 164, 108, 28, 127, 165, 58, 31, 40, 199, 182, 120, 247, 152, 191,
169, 215, 215, 230, 71, 154, 182, 30, 62, 209, 234, 2, 112, 150, 128
]
);
assert!(table_data.enc_sym_key.is_some());
let _ = rsa_private_key
.decrypt(Pkcs1v15Encrypt, table_data.enc_sym_key.as_ref().unwrap())
.unwrap();
assert_eq!(vrf_key.0, retrieved_vrf_key.0);
}
}

View File

@@ -1,8 +1,8 @@
use akd::ecvrf::HardCodedAkdVRF;
use akd::storage::StorageManager;
use akd::Directory;
use akd_storage::DatabaseType;
use akd_storage::db_config::DbConfig;
use akd_storage::DatabaseType;
use anyhow::{Context, Result};
use clap::{Parser, ValueEnum};
use commands::Command;
@@ -65,7 +65,10 @@ enum Mode {
}
#[derive(Parser, Debug, Clone)]
#[clap(name = "akd-test-utility", about = "AKD MS SQL test utility and benchmark tool")]
#[clap(
name = "akd-test-utility",
about = "AKD MS SQL test utility and benchmark tool"
)]
struct CliArgs {
/// Database connection string (also reads from AKD_MSSQL_CONNECTION_STRING env var)
#[clap(long = "connection-string", short = 'c')]
@@ -82,7 +85,11 @@ struct CliArgs {
log_level: LogLevel,
/// Optional log file path (suppresses console logging when specified)
#[clap(long = "log-file", short = 'f', help = "Write logs to file (suppresses console output)")]
#[clap(
long = "log-file",
short = 'f',
help = "Write logs to file (suppresses console output)"
)]
log_file: Option<String>,
/// Connection pool size
@@ -140,10 +147,11 @@ async fn main() -> Result<()> {
use tracing_subscriber::util::SubscriberInitExt;
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| {
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
tracing_subscriber::EnvFilter::new(args.log_level.to_tracing_level().as_str())
}))
}),
)
.with(layers)
.init();
@@ -151,9 +159,15 @@ async fn main() -> Result<()> {
// Create database connection
info!("Connecting to MS SQL database");
let config = DbConfig::MsSql { connection_string, pool_size: args.pool_size };
let db = config.connect().await.context("Failed to connect to database")?;
let config = DbConfig::MsSql {
connection_string,
pool_size: args.pool_size,
};
let db = config
.connect()
.await
.context("Failed to connect to database")?;
// Handle pre-processing modes
if let Some(()) = pre_process_mode(&args, &db).await? {
return Ok(());
@@ -177,10 +191,7 @@ async fn main() -> Result<()> {
}
// Process modes that run before creating the directory
async fn pre_process_mode(
args: &CliArgs,
db: &DatabaseType,
) -> Result<Option<()>> {
async fn pre_process_mode(args: &CliArgs, db: &DatabaseType) -> Result<Option<()>> {
match (db, &args.mode) {
(DatabaseType::MsSql(db), Some(Mode::Drop)) => {
info!("Dropping database tables");
@@ -475,7 +486,7 @@ async fn bench_lookup(
}
async fn repl_loop(
args: &CliArgs,
_args: &CliArgs,
tx: &Sender<directory_host::Rpc>,
db: &DatabaseType,
) -> Result<()> {
@@ -491,7 +502,7 @@ async fn repl_loop(
(_, Command::Unknown(other)) => {
println!("Input '{other}' is not supported, enter 'help' for the help menu")
}
(_,Command::InvalidArgs(message)) => println!("Invalid arguments: {message}"),
(_, Command::InvalidArgs(message)) => println!("Invalid arguments: {message}"),
(_, Command::Exit) => {
info!("Exiting...");
break;
@@ -518,6 +529,7 @@ async fn repl_loop(
}
}
}
#[allow(unreachable_patterns)]
(_, Command::Clean) => {
println!("Clean command is only supported for MS SQL databases");
}

View File

@@ -8,7 +8,7 @@ keywords.workspace = true
[dependencies]
akd.workspace = true
blake3 = "1.8.2"
blake3.workspace = true
config = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
uuid = { version = "1.18.1", features = ["serde"] }

View File

@@ -0,0 +1,26 @@
[package]
name = "bitwarden-encoding"
description = """
Internal crate for the bitwarden crate. Do not use.
"""
edition.workspace = true
version.workspace = true
authors.workspace = true
license-file.workspace = true
keywords.workspace = true
[features]
default = []
[dependencies]
data-encoding = ">=2.0, <3"
data-encoding-macro = "0.1.18"
serde = { workspace = true }
thiserror.workspace = true
[dev-dependencies]
serde_json = ">=1.0.96, <2.0"
[lints]
workspace = true

View File

@@ -0,0 +1,3 @@
# Bitwarden Encoding
Provides Base64 and Base64Url encoding and decoding utilities for working with Bitwarden data.

View File

@@ -0,0 +1,251 @@
use std::str::FromStr;
use data_encoding::BASE64;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::FromStrVisitor;
/// Base64 encoded data
///
/// Is indifferent about padding when decoding, but always produces padding when encoding.
#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
#[serde(into = "String")]
pub struct B64(Vec<u8>);
impl B64 {
/// Returns a byte slice of the inner vector.
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Returns the inner byte vector.
pub fn into_bytes(self) -> Vec<u8> {
self.0
}
}
// We manually implement this to handle both `String` and `&str`
impl<'de> Deserialize<'de> for B64 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(FromStrVisitor::new())
}
}
impl From<Vec<u8>> for B64 {
fn from(src: Vec<u8>) -> Self {
Self(src)
}
}
impl From<&[u8]> for B64 {
fn from(src: &[u8]) -> Self {
Self(src.to_vec())
}
}
impl From<B64> for Vec<u8> {
fn from(src: B64) -> Self {
src.0
}
}
impl From<B64> for String {
fn from(src: B64) -> Self {
String::from(&src)
}
}
impl From<&B64> for String {
fn from(src: &B64) -> Self {
BASE64.encode(&src.0)
}
}
impl std::fmt::Display for B64 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(String::from(self).as_str())
}
}
/// An error returned when a string is not base64 decodable.
#[derive(Debug, Error)]
#[error("Data isn't base64 encoded")]
pub struct NotB64EncodedError;
const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
padding: None,
check_trailing_bits: false,
};
const BASE64_PADDING: &str = "=";
impl TryFrom<String> for B64 {
type Error = NotB64EncodedError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}
impl TryFrom<&str> for B64 {
type Error = NotB64EncodedError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let sane_string = value.trim_end_matches(BASE64_PADDING);
BASE64_PERMISSIVE
.decode(sane_string.as_bytes())
.map(Self)
.map_err(|_| NotB64EncodedError)
}
}
impl FromStr for B64 {
type Err = NotB64EncodedError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(s)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_b64_from_vec() {
let data = vec![72, 101, 108, 108, 111];
let b64 = B64::from(data.clone());
assert_eq!(Vec::<u8>::from(b64), data);
}
#[test]
fn test_b64_from_slice() {
let data = b"Hello";
let b64 = B64::from(data.as_slice());
assert_eq!(b64.as_bytes(), data);
}
#[test]
fn test_b64_encoding_with_padding() {
let data = b"Hello, World!";
let b64 = B64::from(data.as_slice());
let encoded = String::from(&b64);
assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
assert!(encoded.contains('='));
}
#[test]
fn test_b64_decoding_with_padding() {
let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
let b64 = B64::try_from(encoded_with_padding).unwrap();
assert_eq!(b64.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64_decoding_without_padding() {
let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
let b64 = B64::try_from(encoded_without_padding).unwrap();
assert_eq!(b64.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64_round_trip_with_padding() {
let original = b"Test data that requires padding!";
let b64 = B64::from(original.as_slice());
let encoded = String::from(&b64);
let decoded = B64::try_from(encoded.as_str()).unwrap();
assert_eq!(decoded.as_bytes(), original);
}
#[test]
fn test_b64_round_trip_without_padding() {
let original = b"Test data";
let b64 = B64::from(original.as_slice());
let encoded = String::from(&b64);
let decoded = B64::try_from(encoded.as_str()).unwrap();
assert_eq!(decoded.as_bytes(), original);
}
#[test]
fn test_b64_display() {
let data = b"Hello";
let b64 = B64::from(data.as_slice());
assert_eq!(b64.to_string(), "SGVsbG8=");
}
#[test]
fn test_b64_invalid_encoding() {
let invalid_b64 = "This is not base64!@#$";
let result = B64::try_from(invalid_b64);
assert!(result.is_err());
}
#[test]
fn test_b64_empty_string() {
let empty = "";
let b64 = B64::try_from(empty).unwrap();
assert_eq!(b64.as_bytes().len(), 0);
}
#[test]
fn test_b64_padding_removal() {
let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
let b64 = B64::try_from(encoded_with_padding).unwrap();
assert_eq!(b64.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64_serialization() {
let data = b"serialization test";
let b64 = B64::from(data.as_slice());
let serialized = serde_json::to_string(&b64).unwrap();
assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");
let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
assert_eq!(b64.as_bytes(), deserialized.as_bytes());
}
#[test]
fn test_not_b64_encoded_error_display() {
let error = NotB64EncodedError;
assert_eq!(error.to_string(), "Data isn't base64 encoded");
}
#[test]
fn test_b64_from_str() {
let encoded = "SGVsbG8sIFdvcmxkIQ==";
let b64: B64 = encoded.parse().unwrap();
assert_eq!(b64.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64_eq_and_hash() {
let data1 = b"test data";
let data2 = b"test data";
let data3 = b"different data";
let b64_1 = B64::from(data1.as_slice());
let b64_2 = B64::from(data2.as_slice());
let b64_3 = B64::from(data3.as_slice());
assert_eq!(b64_1, b64_2);
assert_ne!(b64_1, b64_3);
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
let mut hasher1 = DefaultHasher::new();
let mut hasher2 = DefaultHasher::new();
b64_1.hash(&mut hasher1);
b64_2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());
}
}

View File

@@ -0,0 +1,238 @@
use std::str::FromStr;
use data_encoding::BASE64URL_NOPAD;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Base64URL encoded data
///
/// Is indifferent about padding when decoding, but always produces padding when encoding.
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
#[serde(try_from = "&str", into = "String")]
pub struct B64Url(Vec<u8>);
impl B64Url {
/// Returns a byte slice of the inner vector.
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Returns the inner byte vector.
pub fn into_bytes(self) -> Vec<u8> {
self.0
}
}
impl From<Vec<u8>> for B64Url {
fn from(src: Vec<u8>) -> Self {
Self(src)
}
}
impl From<&[u8]> for B64Url {
fn from(src: &[u8]) -> Self {
Self(src.to_vec())
}
}
impl From<B64Url> for Vec<u8> {
fn from(src: B64Url) -> Self {
src.0
}
}
impl From<B64Url> for String {
fn from(src: B64Url) -> Self {
String::from(&src)
}
}
impl From<&B64Url> for String {
fn from(src: &B64Url) -> Self {
BASE64URL_NOPAD.encode(&src.0)
}
}
impl std::fmt::Display for B64Url {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(String::from(self).as_str())
}
}
/// An error returned when a string is not base64 decodable.
#[derive(Debug, Error)]
#[error("Data isn't base64url encoded")]
pub struct NotB64UrlEncodedError;
const BASE64URL_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
padding: None,
check_trailing_bits: false,
};
const BASE64URL_PADDING: &str = "=";
impl TryFrom<String> for B64Url {
type Error = NotB64UrlEncodedError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}
impl TryFrom<&str> for B64Url {
type Error = NotB64UrlEncodedError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let sane_string = value.trim_end_matches(BASE64URL_PADDING);
BASE64URL_PERMISSIVE
.decode(sane_string.as_bytes())
.map(Self)
.map_err(|_| NotB64UrlEncodedError)
}
}
impl FromStr for B64Url {
type Err = NotB64UrlEncodedError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(s)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_b64url_from_vec() {
let data = vec![72, 101, 108, 108, 111];
let b64url = B64Url::from(data.clone());
assert_eq!(Vec::<u8>::from(b64url), data);
}
#[test]
fn test_b64url_from_slice() {
let data = b"Hello";
let b64url = B64Url::from(data.as_slice());
assert_eq!(b64url.as_bytes(), data);
}
#[test]
fn test_b64url_encoding_with_padding() {
let data = b"Hello, World!";
let b64url = B64Url::from(data.as_slice());
let encoded = String::from(&b64url);
assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ");
}
#[test]
fn test_b64url_decoding_with_padding() {
let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
let b64url = B64Url::try_from(encoded_with_padding).unwrap();
assert_eq!(b64url.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64url_decoding_without_padding() {
let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
let b64url = B64Url::try_from(encoded_without_padding).unwrap();
assert_eq!(b64url.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64url_round_trip_with_padding() {
let original = b"Test data that requires padding!";
let b64url = B64Url::from(original.as_slice());
let encoded = String::from(&b64url);
let decoded = B64Url::try_from(encoded.as_str()).unwrap();
assert_eq!(decoded.as_bytes(), original);
}
#[test]
fn test_b64url_round_trip_without_padding() {
let original = b"Test data";
let b64url = B64Url::from(original.as_slice());
let encoded = String::from(&b64url);
let decoded = B64Url::try_from(encoded.as_str()).unwrap();
assert_eq!(decoded.as_bytes(), original);
}
#[test]
fn test_b64url_display() {
let data = b"Hello";
let b64url = B64Url::from(data.as_slice());
assert_eq!(b64url.to_string(), "SGVsbG8");
}
#[test]
fn test_b64url_invalid_encoding() {
let invalid_b64url = "This is not base64url!@#$";
let result = B64Url::try_from(invalid_b64url);
assert!(result.is_err());
}
#[test]
fn test_b64url_empty_string() {
let empty = "";
let b64url = B64Url::try_from(empty).unwrap();
assert_eq!(b64url.as_bytes().len(), 0);
}
#[test]
fn test_b64url_padding_removal() {
let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
let b64url = B64Url::try_from(encoded_with_padding).unwrap();
assert_eq!(b64url.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64url_serialization() {
let data = b"serialization test";
let b64url = B64Url::from(data.as_slice());
let serialized = serde_json::to_string(&b64url).unwrap();
assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");
let deserialized: B64Url = serde_json::from_str(&serialized).unwrap();
assert_eq!(b64url.as_bytes(), deserialized.as_bytes());
}
#[test]
fn test_not_b64url_encoded_error_display() {
let error = NotB64UrlEncodedError;
assert_eq!(error.to_string(), "Data isn't base64url encoded");
}
#[test]
fn test_b64url_from_str() {
let encoded = "SGVsbG8sIFdvcmxkIQ==";
let b64url: B64Url = encoded.parse().unwrap();
assert_eq!(b64url.as_bytes(), b"Hello, World!");
}
#[test]
fn test_b64url_eq_and_hash() {
let data1 = b"test data";
let data2 = b"test data";
let data3 = b"different data";
let b64url_1 = B64Url::from(data1.as_slice());
let b64url_2 = B64Url::from(data2.as_slice());
let b64url_3 = B64Url::from(data3.as_slice());
assert_eq!(b64url_1, b64url_2);
assert_ne!(b64url_1, b64url_3);
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
let mut hasher1 = DefaultHasher::new();
let mut hasher2 = DefaultHasher::new();
b64url_1.hash(&mut hasher1);
b64url_2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());
}
}

View File

@@ -0,0 +1,9 @@
#![doc = include_str!("../README.md")]
mod b64;
mod b64url;
mod serde;
pub use b64::{NotB64EncodedError, B64};
pub use b64url::{B64Url, NotB64UrlEncodedError};
pub use serde::FromStrVisitor;

View File

@@ -0,0 +1,32 @@
use std::str::FromStr;
/// A serde visitor that converts a string to a type that implements `FromStr`.
pub struct FromStrVisitor<T>(std::marker::PhantomData<T>);
impl<T> FromStrVisitor<T> {
/// Create a new `FromStrVisitor` for the given type.
pub fn new() -> Self {
Self::default()
}
}
impl<T> Default for FromStrVisitor<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T: FromStr> serde::de::Visitor<'_> for FromStrVisitor<T>
where
T::Err: std::fmt::Debug,
{
type Value = T;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "a valid string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
T::from_str(v).map_err(|e| E::custom(format!("{e:?}")))
}
}

View File

@@ -9,10 +9,10 @@ keywords.workspace = true
[dependencies]
akd = "0.11.0"
async-trait = { workspace = true }
akd_storage = { workspace = true}
akd_storage = { workspace = true }
config = { workspace = true }
serde = { workspace = true }
thiserror = "2.0.17"
thiserror.workspace = true
tracing.workspace = true
hex = "0.4.3"

View File

@@ -10,7 +10,7 @@ keywords.workspace = true
async-trait = { workspace = true }
bb8 = "0.9.0"
macros = { path = "../macros" }
thiserror = "2.0.17"
thiserror.workspace = true
tokio = { workspace = true }
tokio-util = { version = "0.7.16", features = ["compat"] }
tracing = { workspace = true }

View File

@@ -5,7 +5,7 @@ use tokio_util::compat::TokioAsyncWriteCompatExt;
use bb8::ManageConnection;
use tiberius::{Client, Config};
use tracing::{debug, instrument, info};
use tracing::{debug, info, instrument};
#[derive(thiserror::Error, Debug)]
pub enum OnConnectError {
@@ -66,7 +66,7 @@ impl ManagedConnection {
pub async fn execute(
&mut self,
sql: &str,
params: &[&(dyn tiberius::ToSql)],
params: &[&dyn tiberius::ToSql],
) -> Result<tiberius::ExecuteResult, tiberius::error::Error> {
debug!("Executing command");
self.0.execute(sql, params).await
@@ -76,7 +76,7 @@ impl ManagedConnection {
pub async fn query<'a>(
&'a mut self,
sql: &str,
params: &[&(dyn tiberius::ToSql)],
params: &[&dyn tiberius::ToSql],
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
debug!("Executing query");
self.0.query(sql, params).await
@@ -141,8 +141,6 @@ impl ManageConnection for ConnectionManager {
}
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
*self.is_healthy
.read()
.expect("poisoned is_healthy lock")
*self.is_healthy.read().expect("poisoned is_healthy lock")
}
}

View File

@@ -5,14 +5,14 @@ use common::VrfStorageType;
use tracing::instrument;
struct AppState {
directory: Directory<BitwardenV1Configuration, DatabaseType, VrfStorageType>,
_directory: Directory<BitwardenV1Configuration, DatabaseType, VrfStorageType>,
}
#[instrument(skip_all, name = "publisher_start")]
pub async fn start_write_job(_db: DatabaseType, vrf: VrfStorageType) {
let storage_manager = StorageManager::new_no_cache(_db);
let _app_state = AppState {
directory: Directory::new(storage_manager, vrf).await.unwrap(),
_directory: Directory::new(storage_manager, vrf).await.unwrap(),
};
println!("Publisher started");
}

View File

@@ -1,21 +1,21 @@
use std::sync::Arc;
use akd::{directory::ReadOnlyDirectory, storage::StorageManager};
use tracing::instrument;
use bitwarden_akd_configuration::BitwardenV1Configuration;
use akd_storage::DatabaseType;
use bitwarden_akd_configuration::BitwardenV1Configuration;
use common::VrfStorageType;
use tracing::instrument;
struct AppState {
// Add any shared state here, e.g., database connections
directory: ReadOnlyDirectory<BitwardenV1Configuration, DatabaseType, VrfStorageType>,
_directory: ReadOnlyDirectory<BitwardenV1Configuration, DatabaseType, VrfStorageType>,
}
#[instrument(skip_all, name = "reader_start")]
pub async fn start(db: DatabaseType, vrf: VrfStorageType) {
let storage_manager = StorageManager::new_no_cache(db);
let _app = AppState {
directory: ReadOnlyDirectory::new(storage_manager, vrf).await.unwrap(),
_directory: ReadOnlyDirectory::new(storage_manager, vrf)
.await
.expect("Failed to create ReadOnlyDirectory"),
};
println!("Reader started");
}

View File

@@ -1,7 +1,6 @@
//! The Reader crate is responsible for handling read requests to the AKD. It requires only read permissions to the
//! underlying data stores, and can be horizontally scaled as needed.
use akd::ecvrf::VRFKeyStorage;
use akd_storage::db_config::DbConfig;
use common::VrfStorageType;
use reader::start;

View File

@@ -1,6 +1,6 @@
[toolchain]
channel = "1.88.0"
components = [ "rustfmt", "clippy" ]
channel = "1.90.0"
components = ["rustfmt", "clippy"]
profile = "minimal"
# The following is not part of the rust-toolchain.toml format,