1
0
mirror of https://github.com/bitwarden/server synced 2026-01-30 00:03:48 +00:00

Move mssql db to sub directory

This commit is contained in:
Matt Gibson
2025-12-10 12:47:28 -08:00
parent 8b750aee0a
commit 6e1758afde
9 changed files with 82 additions and 39 deletions

View File

@@ -12,14 +12,9 @@ use async_trait::async_trait;
use crate::ms_sql::MsSql;
mod migrations;
pub mod ms_sql;
mod ms_sql_storable;
mod sql_params;
mod tables;
mod temp_table;
pub mod akd_storage_config;
pub mod db_config;
pub mod ms_sql;
/// Enum to represent different database types supported by the storage layer.
/// Each variant is cheap to clone for reuse across threads.

View File

@@ -1,6 +1,9 @@
pub(crate) mod migrations;
pub(crate) mod sql_params;
pub(crate) mod tables;
use std::{cmp::Ordering, collections::HashMap, sync::Arc};
use crate::migrations::{TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES};
use akd::{
errors::StorageError,
storage::{
@@ -13,11 +16,13 @@ use async_trait::async_trait;
use ms_database::{IntoRow, MsSqlConnectionManager, Pool, PooledConnection};
use tracing::{debug, error, info, instrument, trace, warn};
use crate::{
migrations::MIGRATIONS,
ms_sql_storable::{MsSqlStorable, Statement},
tables::values,
use migrations::{
MIGRATIONS, TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_MIGRATIONS, TABLE_VALUES,
};
use tables::{
akd_storable_for_ms_sql::{AkdStorableForMsSql, Statement},
temp_table::TempTable,
values,
};
const DEFAULT_POOL_SIZE: u32 = 100;
@@ -284,7 +289,7 @@ impl Database for MsSql {
// Set values from temp table to main table
debug!("Merging temp table data into main table");
let sql = <DbRecord as MsSqlStorable>::set_batch_statement(&storage_type);
let sql = <DbRecord as AkdStorableForMsSql>::set_batch_statement(&storage_type);
trace!(sql, "Batch merge SQL");
conn.simple_query(&sql).await.map_err(|e| {
error!(error = %e, "Failed to execute batch set statement");

View File

@@ -10,10 +10,11 @@ use akd::{
use ms_database::{ColumnData, IntoRow, Row, ToSql, TokenRow};
use tracing::debug;
use crate::{migrations::{
TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_VALUES
}, temp_table::TempTable};
use crate::sql_params::SqlParams;
use crate::ms_sql::{
migrations::{TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_VALUES},
sql_params::SqlParams,
tables::{temp_table::TempTable, values},
};
const SELECT_AZKS_DATA: &'static [&str] = &["epoch", "num_nodes"];
const SELECT_HISTORY_TREE_NODE_DATA: &'static [&str] = &[
@@ -76,8 +77,16 @@ pub(crate) struct QueryStatement<Out> {
}
impl<Out> QueryStatement<Out> {
pub fn new(sql: String, params: SqlParams, parser: fn(&Row) -> Result<Out, StorageError>) -> Self {
Self { sql, params, parser }
pub fn new(
sql: String,
params: SqlParams,
parser: fn(&Row) -> Result<Out, StorageError>,
) -> Self {
Self {
sql,
params,
parser,
}
}
pub fn sql(&self) -> &str {
@@ -93,7 +102,7 @@ impl<Out> QueryStatement<Out> {
}
}
pub(crate) trait MsSqlStorable {
pub(crate) trait AkdStorableForMsSql {
fn set_statement(&self) -> Result<Statement, StorageError>;
fn set_batch_statement(storage_type: &StorageType) -> String;
@@ -113,7 +122,7 @@ pub(crate) trait MsSqlStorable {
fn into_row(&self) -> Result<TokenRow, StorageError>;
}
impl MsSqlStorable for DbRecord {
impl AkdStorableForMsSql for DbRecord {
fn set_statement(&self) -> Result<Statement, StorageError> {
let record_type = match &self {
DbRecord::Azks(_) => "Azks",
@@ -123,10 +132,14 @@ impl MsSqlStorable for DbRecord {
debug!(record_type, "Generating set statement");
match &self {
DbRecord::Azks(azks) => {
debug!(epoch = azks.latest_epoch, num_nodes = azks.num_nodes, "Building AZKS set statement");
debug!(
epoch = azks.latest_epoch,
num_nodes = azks.num_nodes,
"Building AZKS set statement"
);
let mut params = SqlParams::new();
params.add("akd_key", Box::new(1i16)); // constant key
// TODO: Fixup as conversions
// TODO: Fixup as conversions
params.add("epoch", Box::new(azks.latest_epoch as i64));
params.add("num_nodes", Box::new(azks.num_nodes as i64));
@@ -142,7 +155,9 @@ impl MsSqlStorable for DbRecord {
VALUES ({});
"#,
params.keys_as_columns().join(", "),
params.set_columns_equal_except("t.", "source.", vec!["akd_key"]).join(", "),
params
.set_columns_equal_except("t.", "source.", vec!["akd_key"])
.join(", "),
params.columns().join(", "),
params.keys().join(", ")
);
@@ -264,11 +279,7 @@ impl MsSqlStorable for DbRecord {
"#,
params.keys_as_columns().join(", "),
params
.set_columns_equal_except(
"t.",
"source.",
vec!["label_len", "label_val"]
)
.set_columns_equal_except("t.", "source.", vec!["label_len", "label_val"])
.join(", "),
params.columns().join(", "),
params.keys().join(", "),
@@ -666,7 +677,7 @@ impl MsSqlStorable for DbRecord {
Ok(DbRecord::TreeNode(node))
}
StorageType::ValueState => Ok(DbRecord::ValueState(crate::tables::values::from_row(row)?)),
StorageType::ValueState => Ok(DbRecord::ValueState(values::from_row(row)?)),
}
}

View File

@@ -0,0 +1,3 @@
pub(crate) mod akd_storable_for_ms_sql;
pub(crate) mod temp_table;
pub(crate) mod values;

View File

@@ -4,9 +4,12 @@ use akd::{
AkdLabel, AkdValue,
};
use ms_database::Row;
use tracing::{debug};
use tracing::debug;
use crate::{migrations::TABLE_VALUES, ms_sql_storable::QueryStatement, sql_params::SqlParams};
use crate::ms_sql::{
migrations::TABLE_VALUES, sql_params::SqlParams,
tables::akd_storable_for_ms_sql::QueryStatement,
};
pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
debug!("Building get_all query for label (label not logged for privacy)");
@@ -20,14 +23,18 @@ pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
FROM {}
WHERE raw_label = {}
"#,
TABLE_VALUES, params
TABLE_VALUES,
params
.key_for("raw_label")
.expect("raw_label was added to the params list")
);
QueryStatement::new(sql, params, from_row)
}
pub fn get_by_flag(raw_label: &AkdLabel, flag: ValueStateRetrievalFlag) -> QueryStatement<ValueState> {
pub fn get_by_flag(
raw_label: &AkdLabel,
flag: ValueStateRetrievalFlag,
) -> QueryStatement<ValueState> {
debug!(?flag, "Building get_by_flag query with flag");
let mut params = SqlParams::new();
params.add("raw_label", Box::new(raw_label.0.clone()));
@@ -101,15 +108,39 @@ pub fn get_versions_by_flag(
let (filter, epoch_col) = match flag {
ValueStateRetrievalFlag::SpecificVersion(version) => {
params.add("version", Box::new(version as i64));
(format!("WHERE tmp.version = {}", params.key_for("version").expect("version was added to the params list")), "tmp.epoch")
(
format!(
"WHERE tmp.version = {}",
params
.key_for("version")
.expect("version was added to the params list")
),
"tmp.epoch",
)
}
ValueStateRetrievalFlag::SpecificEpoch(epoch) => {
params.add("epoch", Box::new(epoch as i64));
(format!("WHERE tmp.epoch = {}", params.key_for("epoch").expect("epoch was added to the params list")), "tmp.epoch")
(
format!(
"WHERE tmp.epoch = {}",
params
.key_for("epoch")
.expect("epoch was added to the params list")
),
"tmp.epoch",
)
}
ValueStateRetrievalFlag::LeqEpoch(epoch) => {
params.add("epoch", Box::new(epoch as i64));
(format!("WHERE tmp.epoch <= {}", params.key_for("epoch").expect("epoch was added to the params list")), "MAX(tmp.epoch)")
(
format!(
"WHERE tmp.epoch <= {}",
params
.key_for("epoch")
.expect("epoch was added to the params list")
),
"MAX(tmp.epoch)",
)
}
ValueStateRetrievalFlag::MaxEpoch => ("".to_string(), "MAX(tmp.epoch)"),
ValueStateRetrievalFlag::MinEpoch => ("".to_string(), "MIN(tmp.epoch)"),
@@ -127,8 +158,7 @@ pub fn get_versions_by_flag(
GROUP BY tmp.raw_label
) epochs on epochs.raw_label = t.raw_label AND epochs.epoch = t.epoch
"#,
epoch_col,
filter,
epoch_col, filter,
);
QueryStatement::new(sql, params, version_from_row)

View File

@@ -1 +0,0 @@
pub(crate) mod values;