First complete implementation of the Database trait for SQL Server
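
For context, a minimal usage sketch of the new API (the connection string fields are illustrative, not tested values):

    let db = MsSql::builder("server=tcp:localhost,1433;IntegratedSecurity=true;TrustServerCertificate=true".to_string())
        .pool_size(10)
        .build()
        .await?;
    db.migrate().await?;
    // db now implements akd's Database trait.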

@@ -2,4 +2,5 @@ mod migrations;
mod sql_params;
mod ms_sql_storable;
pub mod ms_sql;
mod tables;
mod temp_table;

@@ -1,48 +1,485 @@
use std::{cmp::Ordering, collections::HashMap, sync::Arc};

use akd::{
    errors::StorageError,
    storage::{
        types::{self, DbRecord, KeyData, StorageType},
        Database, DbSetState, Storable,
    },
    AkdLabel, AkdValue,
};
use async_trait::async_trait;
use ms_database::{IntoRow, MsSqlConnectionManager, Pool, PooledConnection};

use crate::{
    migrations::MIGRATIONS,
    ms_sql_storable::{MsSqlStorable, Statement},
    tables::values,
    temp_table::TempTable,
};

const DEFAULT_POOL_SIZE: u32 = 100;

pub struct MsSqlBuilder {
    connection_string: String,
    pool_size: Option<u32>,
}

impl MsSqlBuilder {
    /// Create a new [MsSqlBuilder] with the given connection string.
    pub fn new(connection_string: String) -> Self {
        Self {
            connection_string,
            pool_size: None,
        }
    }

    /// Set the connection pool size. Default is given by [DEFAULT_POOL_SIZE].
    pub fn pool_size(mut self, pool_size: u32) -> Self {
        self.pool_size = Some(pool_size);
        self
    }

    /// Build the [MsSql] instance.
    pub async fn build(self) -> Result<MsSql, StorageError> {
        let pool_size = self.pool_size.unwrap_or(DEFAULT_POOL_SIZE);

        MsSql::new(self.connection_string, pool_size).await
    }
}

pub struct MsSql {
    pool: Arc<Pool>,
}

impl MsSql {
    pub fn builder(connection_string: String) -> MsSqlBuilder {
        MsSqlBuilder::new(connection_string)
    }

    pub async fn new(connection_string: String, pool_size: u32) -> Result<Self, StorageError> {
        let connection_manager = MsSqlConnectionManager::new(connection_string);
        let pool = Pool::builder()
            .max_size(pool_size)
            .build(connection_manager)
            .await
            .map_err(|e| StorageError::Connection(format!("Failed to create DB pool: {}", e)))?;

        Ok(Self {
            pool: Arc::new(pool),
        })
    }

    pub async fn migrate(&self) -> Result<(), StorageError> {
        let mut conn = self.pool.get().await.map_err(|e| {
            StorageError::Connection(format!("Failed to get DB connection for migrations: {}", e))
        })?;

        ms_database::run_pending_migrations(&mut conn, MIGRATIONS)
            .await
            .map_err(|e| StorageError::Connection(format!("Failed to run migrations: {}", e)))?;
        Ok(())
    }

    async fn get_connection(&self) -> Result<PooledConnection<'_>, StorageError> {
        self.pool
            .get()
            .await
            .map_err(|e| StorageError::Connection(format!("Failed to get DB connection: {}", e)))
    }

    async fn execute_statement(&self, statement: &Statement) -> Result<(), StorageError> {
        let mut conn = self.get_connection().await?;
        self.execute_statement_on_connection(statement, &mut conn)
            .await
    }

    async fn execute_statement_on_connection(
        &self,
        statement: &Statement,
        conn: &mut PooledConnection<'_>,
    ) -> Result<(), StorageError> {
        conn.execute(statement.sql(), &statement.params())
            .await
            .map_err(|e| StorageError::Other(format!("Failed to execute statement: {}", e)))?;
        Ok(())
    }
}

#[async_trait]
impl Database for MsSql {
    async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
        let statement = record.set_statement()?;
        self.execute_statement(&statement).await
    }
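
    // batch_set strategy: group records by storage type, bulk-insert each
    // group into a session-scoped temp table, then MERGE the temp table into
    // the real table, all inside one transaction so a failure rolls back
    // every group.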
    async fn batch_set(
        &self,
        records: Vec<DbRecord>,
        _state: DbSetState, // TODO: unused in mysql example, but may be needed later
    ) -> Result<(), StorageError> {
        // Generate groups by type
        let mut groups = HashMap::new();
        for record in records {
            match &record {
                DbRecord::Azks(_) => groups
                    .entry(StorageType::Azks)
                    .or_insert_with(Vec::new)
                    .push(record),
                DbRecord::TreeNode(_) => groups
                    .entry(StorageType::TreeNode)
                    .or_insert_with(Vec::new)
                    .push(record),
                DbRecord::ValueState(_) => groups
                    .entry(StorageType::ValueState)
                    .or_insert_with(Vec::new)
                    .push(record),
            }
        }

        // Execute each group in batches
        let mut conn = self.get_connection().await?;
        // Start transaction
        conn.simple_query("BEGIN TRANSACTION")
            .await
            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
        let result = async {
            for (storage_type, mut record_group) in groups.into_iter() {
                if record_group.is_empty() {
                    continue;
                }

                // Sort the records to match db-layer sorting
                record_group.sort_by(|a, b| match &a {
                    DbRecord::TreeNode(node) => {
                        if let DbRecord::TreeNode(node2) = &b {
                            node.label.cmp(&node2.label)
                        } else {
                            Ordering::Equal
                        }
                    }
                    DbRecord::ValueState(state) => {
                        if let DbRecord::ValueState(state2) = &b {
                            match state.username.0.cmp(&state2.username.0) {
                                Ordering::Equal => state.epoch.cmp(&state2.epoch),
                                other => other,
                            }
                        } else {
                            Ordering::Equal
                        }
                    }
                    _ => Ordering::Equal,
                });

                // Execute the group as a bulk insert

                // Create temp table
                let table: TempTable = storage_type.into();
                conn.simple_query(&table.create()).await.map_err(|e| {
                    StorageError::Other(format!("Failed to create temp table: {e}"))
                })?;

                // Create bulk insert
                let table_name = &table.to_string();
                let mut bulk = conn.bulk_insert(table_name).await.map_err(|e| {
                    StorageError::Other(format!("Failed to start bulk insert: {e}"))
                })?;

                for record in &record_group {
                    let row = record.into_row()?;
                    bulk.send(row).await.map_err(|e| {
                        StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
                    })?;
                }

                bulk.finalize().await.map_err(|e| {
                    StorageError::Other(format!("Failed to finalize bulk insert: {e}"))
                })?;

                // Set values from temp table to main table
                let sql = <DbRecord as MsSqlStorable>::set_batch_statement(&storage_type);
                conn.simple_query(&sql).await.map_err(|e| {
                    StorageError::Other(format!("Failed to execute batch set statement: {e}"))
                })?;

                // Delete the temp table
                conn.simple_query(&table.drop())
                    .await
                    .map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;
            }

            Ok::<(), StorageError>(())
        };

        match result.await {
            Ok(_) => {
                conn.simple_query("COMMIT").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
                })?;
                Ok(())
            }
            Err(e) => {
                conn.simple_query("ROLLBACK").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to roll back transaction: {e}"))
                })?;
                Err(StorageError::Other(format!(
                    "Failed to batch set records: {}",
                    e
                )))
            }
        }
    }

    async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
        let mut conn = self.get_connection().await?;
        let statement = DbRecord::get_statement::<St>(id)?;

        let query_stream = conn
            .query(statement.sql(), &statement.params())
            .await
            .map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;

        let row = query_stream
            .into_row()
            .await
            .map_err(|e| StorageError::Other(format!("Failed to fetch row: {e}")))?;

        if let Some(row) = row {
            DbRecord::from_row::<St>(&row)
        } else {
            Err(StorageError::NotFound(format!(
                "{:?} {:?} not found",
                St::data_type(),
                id
            )))
        }
    }
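
    // batch_get strategy: bulk-insert the requested ids into a temp table and
    // read the real table back through an INNER JOIN against it, which avoids
    // building a huge parameterized IN (...) list.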
    async fn batch_get<St: Storable>(
        &self,
        ids: &[St::StorageKey],
    ) -> Result<Vec<DbRecord>, StorageError> {
        if ids.is_empty() {
            return Ok(vec![]);
        }

        let temp_table = TempTable::for_ids::<St>();
        if !temp_table.can_create() {
            // AZKs does not support batch get, so we just return empty vec
            return Ok(Vec::new());
        }
        let create_temp_table = temp_table.create();
        let temp_table_name = &temp_table.to_string();

        let mut conn = self.get_connection().await?;

        // Begin a transaction
        conn.simple_query("BEGIN TRANSACTION")
            .await
            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;

        let result = async {
            // Use bulk_insert to insert all the ids into a temporary table
            conn.simple_query(&create_temp_table)
                .await
                .map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;
            let mut bulk = conn
                .bulk_insert(temp_table_name)
                .await
                .map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
            for row in DbRecord::get_batch_temp_table_rows::<St>(ids)? {
                bulk.send(row).await.map_err(|e| {
                    StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
                })?;
            }
            bulk.finalize()
                .await
                .map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;

            // Read rows matching the ids from the temporary table
            let get_sql = DbRecord::get_batch_statement::<St>();
            let query_stream = conn.simple_query(&get_sql).await.map_err(|e| {
                StorageError::Other(format!("Failed to execute batch get query: {e}"))
            })?;
            let mut records = Vec::new();
            {
                let rows = query_stream
                    .into_first_result()
                    .await
                    .map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
                for row in rows {
                    let record = DbRecord::from_row::<St>(&row)?;
                    records.push(record);
                }
            }

            let drop_temp_table = temp_table.drop();
            conn.simple_query(&drop_temp_table)
                .await
                .map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;

            Ok::<Vec<DbRecord>, StorageError>(records)
        };

        // Commit or rollback the transaction
        match result.await {
            Ok(records) => {
                conn.simple_query("COMMIT").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
                })?;
                Ok(records)
            }
            Err(e) => {
                conn.simple_query("ROLLBACK").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
                })?;
                Err(e)
            }
        }
    }

    // Note: user and username here is the raw_label. The assumption is this is
    // a single key for a single user, but that's too restrictive for what we
    // want, so generalize the name a bit.
    async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<types::KeyData, StorageError> {
        // Note: don't log raw_label or data as it may contain sensitive information, such as PII.

        let result = async {
            let mut conn = self.get_connection().await?;

            let statement = values::get_all(raw_label);

            let query_stream = conn
                .query(statement.sql(), &statement.params())
                .await
                .map_err(|e| {
                    StorageError::Other(format!("Failed to execute query for label: {e}"))
                })?;
            let mut states = Vec::new();
            {
                let rows = query_stream.into_first_result().await.map_err(|e| {
                    StorageError::Other(format!("Failed to fetch rows for label: {e}"))
                })?;
                for row in rows {
                    let record = statement.parse(&row)?;
                    states.push(record);
                }
            }
            let key_data = KeyData { states };

            Ok::<KeyData, StorageError>(key_data)
        };

        match result.await {
            Ok(data) => Ok(data),
            Err(e) => Err(StorageError::Other(format!(
                "Failed to get all data for label: {}",
                e
            ))),
        }
    }

    // Note: user and username here is the raw_label. The assumption is this is
    // a single key for a single user, but that's too restrictive for what we
    // want, so generalize the name a bit.
    async fn get_user_state(
        &self,
        raw_label: &AkdLabel,
        flag: types::ValueStateRetrievalFlag,
    ) -> Result<types::ValueState, StorageError> {
        let statement = values::get_by_flag(raw_label, flag);
        let mut conn = self.get_connection().await?;
        let query_stream = conn
            .query(statement.sql(), &statement.params())
            .await
            .map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;
        let row = query_stream
            .into_row()
            .await
            .map_err(|e| StorageError::Other(format!("Failed to fetch first result row: {e}")))?;
        if let Some(row) = row {
            statement.parse(&row)
        } else {
            Err(StorageError::NotFound(format!(
                "ValueState for label {:?} not found",
                raw_label
            )))
        }
    }

    // Note: user and username here is the raw_label. The assumption is this is
    // a single key for a single user, but that's too restrictive for what we
    // want, so generalize the name a bit.
    async fn get_user_state_versions(
        &self,
        raw_labels: &[AkdLabel],
        flag: types::ValueStateRetrievalFlag,
    ) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
        let mut conn = self.get_connection().await?;

        let temp_table = TempTable::RawLabelSearch;
        let create_temp_table = temp_table.create();
        let temp_table_name = &temp_table.to_string();

        // Begin a transaction
        conn.simple_query("BEGIN TRANSACTION")
            .await
            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;

        let result = async {
            conn.simple_query(&create_temp_table)
                .await
                .map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;

            // Use bulk_insert to insert all the raw_labels into a temporary table
            let mut bulk = conn
                .bulk_insert(temp_table_name)
                .await
                .map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
            for raw_label in raw_labels {
                let row = (raw_label.0.clone()).into_row();
                bulk.send(row).await.map_err(|e| {
                    StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
                })?;
            }
            bulk.finalize()
                .await
                .map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;

            // Read rows matching the raw_labels from the temporary table
            let statement = values::get_versions_by_flag(temp_table_name, flag);
            let query_stream = conn
                .query(statement.sql(), &statement.params())
                .await
                .map_err(|e| {
                    StorageError::Other(format!("Failed to execute batch get query: {e}"))
                })?;

            let mut results = HashMap::new();
            let rows = query_stream
                .into_first_result()
                .await
                .map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
            for row in rows {
                let label_version = statement.parse(&row)?;
                results.insert(
                    label_version.label,
                    (label_version.version, label_version.data),
                );
            }
            Ok(results)
        };

        match result.await {
            Ok(results) => {
                conn.simple_query("COMMIT").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
                })?;
                Ok(results)
            }
            Err(e) => {
                conn.simple_query("ROLLBACK").await.map_err(|e| {
                    StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
                })?;
                Err(e)
            }
        }
    }
}

akd/crates/akd_storage/src/ms_sql_storable.rs (new file, 781 lines)
@@ -0,0 +1,781 @@
use akd::{
    errors::StorageError,
    storage::{
        types::{DbRecord, StorageType, ValueState},
        Storable,
    },
    tree_node::TreeNodeWithPreviousValue,
    NodeLabel,
};
use ms_database::{ColumnData, IntoRow, Row, ToSql, TokenRow};

use crate::migrations::{TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_VALUES};
use crate::sql_params::SqlParams;
use crate::temp_table::TempTable;

const SELECT_AZKS_DATA: &[&str] = &["epoch", "num_nodes"];
const SELECT_HISTORY_TREE_NODE_DATA: &[&str] = &[
    "label_len",
    "label_val",
    "last_epoch",
    "least_descendant_ep",
    "parent_label_len",
    "parent_label_val",
    "node_type",
    "left_child_len",
    "left_child_label_val",
    "right_child_len",
    "right_child_label_val",
    "hash",
    "p_last_epoch",
    "p_least_descendant_ep",
    "p_parent_label_len",
    "p_parent_label_val",
    "p_node_type",
    "p_left_child_len",
    "p_left_child_label_val",
    "p_right_child_len",
    "p_right_child_label_val",
    "p_hash",
];
const SELECT_LABEL_DATA: &[&str] = &[
    "raw_label",
    "epoch",
    "version",
    "node_label_val",
    "node_label_len",
    "data",
];
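
// Statement carries write SQL plus its bound parameters; QueryStatement
// additionally carries a row parser, so callers can turn tiberius Rows into
// typed results without knowing the column layout.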
pub(crate) struct Statement {
    sql: String,
    params: SqlParams,
}

impl Statement {
    pub fn new(sql: String, params: SqlParams) -> Self {
        Self { sql, params }
    }

    pub fn sql(&self) -> &str {
        &self.sql
    }

    pub fn params(&self) -> Vec<&dyn ToSql> {
        self.params.values()
    }
}

pub(crate) struct QueryStatement<Out> {
    sql: String,
    params: SqlParams,
    parser: fn(&Row) -> Result<Out, StorageError>,
}

impl<Out> QueryStatement<Out> {
    pub fn new(
        sql: String,
        params: SqlParams,
        parser: fn(&Row) -> Result<Out, StorageError>,
    ) -> Self {
        Self { sql, params, parser }
    }

    pub fn sql(&self) -> &str {
        &self.sql
    }

    pub fn params(&self) -> Vec<&dyn ToSql> {
        self.params.values()
    }

    pub fn parse(&self, row: &Row) -> Result<Out, StorageError> {
        (self.parser)(row)
    }
}

pub(crate) trait MsSqlStorable {
    fn set_statement(&self) -> Result<Statement, StorageError>;

    fn set_batch_statement(storage_type: &StorageType) -> String;

    fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError>;

    fn get_batch_temp_table_rows<St: Storable>(
        key: &[St::StorageKey],
    ) -> Result<Vec<TokenRow>, StorageError>;

    fn get_batch_statement<St: Storable>() -> String;

    fn from_row<St: Storable>(row: &Row) -> Result<Self, StorageError>
    where
        Self: Sized;

    fn into_row(&self) -> Result<TokenRow, StorageError>;
}
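
// Single-record writes below are rendered as parameterized MERGE (upsert)
// statements: SqlParams supplies the @P1, @P2, ... placeholders that are
// spliced into the SQL text, while the boxed values travel separately.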
impl MsSqlStorable for DbRecord {
    fn set_statement(&self) -> Result<Statement, StorageError> {
        match &self {
            DbRecord::Azks(azks) => {
                let mut params = SqlParams::new();
                params.add("key", Box::new(1u8)); // constant key
                // TODO: Fixup as conversions
                params.add("epoch", Box::new(azks.latest_epoch as i64));
                params.add("num_nodes", Box::new(azks.num_nodes as i64));

                let sql = format!(
                    r#"
                    MERGE INTO dbo.{TABLE_AZKS} AS target
                    USING (SELECT {}) AS source
                    ON target.[key] = source.[key]
                    WHEN MATCHED THEN
                        UPDATE SET {}
                    WHEN NOT MATCHED THEN
                        INSERT ({})
                        VALUES ({});
                    "#,
                    params.keys_as_columns().join(", "),
                    params.set_columns_equal("target.", "source.").join(", "),
                    params.columns().join(", "),
                    params.keys().join(", ")
                );

                Ok(Statement::new(sql, params))
            }
            DbRecord::TreeNode(node) => {
                let mut params = SqlParams::new();
                params.add("label_len", Box::new(node.label.label_len as i32));
                params.add("label_val", Box::new(node.label.label_val.to_vec()));
                // Latest node values
                params.add("last_epoch", Box::new(node.latest_node.last_epoch as i64));
                params.add(
                    "least_descendant_ep",
                    Box::new(node.latest_node.min_descendant_epoch as i64),
                );
                params.add(
                    "parent_label_len",
                    Box::new(node.latest_node.parent.label_len as i32),
                );
                params.add(
                    "parent_label_val",
                    Box::new(node.latest_node.parent.label_val.to_vec()),
                );
                params.add("node_type", Box::new(node.latest_node.node_type as i16));
                params.add(
                    "left_child_len",
                    Box::new(node.latest_node.left_child.map(|lc| lc.label_len as i32)),
                );
                params.add(
                    "left_child_label_val",
                    Box::new(node.latest_node.left_child.map(|lc| lc.label_val.to_vec())),
                );
                params.add(
                    "right_child_len",
                    Box::new(node.latest_node.right_child.map(|rc| rc.label_len as i32)),
                );
                params.add(
                    "right_child_label_val",
                    Box::new(node.latest_node.right_child.map(|rc| rc.label_val.to_vec())),
                );
                params.add("[hash]", Box::new(node.latest_node.hash.0.to_vec()));
                // Previous node values
                params.add(
                    "p_last_epoch",
                    Box::new(node.previous_node.clone().map(|p| p.last_epoch as i64)),
                );
                params.add(
                    "p_least_descendant_ep",
                    Box::new(
                        node.previous_node
                            .clone()
                            .map(|p| p.min_descendant_epoch as i64),
                    ),
                );
                params.add(
                    "p_parent_label_len",
                    Box::new(node.previous_node.clone().map(|p| p.label.label_len as i32)),
                );
                params.add(
                    "p_parent_label_val",
                    Box::new(
                        node.previous_node
                            .clone()
                            .map(|p| p.label.label_val.to_vec()),
                    ),
                );
                params.add(
                    "p_node_type",
                    Box::new(node.previous_node.clone().map(|p| p.node_type as i16)),
                );
                params.add(
                    "p_left_child_len",
                    Box::new(
                        node.previous_node
                            .clone()
                            .and_then(|p| p.left_child.map(|lc| lc.label_len as i32)),
                    ),
                );
                params.add(
                    "p_left_child_label_val",
                    Box::new(
                        node.previous_node
                            .clone()
                            .and_then(|p| p.left_child.map(|lc| lc.label_val.to_vec())),
                    ),
                );
                params.add(
                    "p_right_child_len",
                    Box::new(
                        node.previous_node
                            .clone()
                            .and_then(|p| p.right_child.map(|rc| rc.label_len as i32)),
                    ),
                );
                params.add(
                    "p_right_child_label_val",
                    Box::new(
                        node.previous_node
                            .clone()
                            .and_then(|p| p.right_child.map(|rc| rc.label_val.to_vec())),
                    ),
                );
                params.add(
                    "p_hash",
                    Box::new(node.previous_node.clone().map(|p| p.hash.0.to_vec())),
                );

                let sql = format!(
                    r#"
                    MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target
                    USING (SELECT {}) AS source
                    ON target.label_len = source.label_len AND target.label_val = source.label_val
                    WHEN MATCHED THEN
                        UPDATE SET {}
                    WHEN NOT MATCHED THEN
                        INSERT ({})
                        VALUES ({});
                    "#,
                    params.keys_as_columns().join(", "),
                    params
                        .set_columns_equal_except(
                            "target.",
                            "source.",
                            vec!["label_len", "label_val"]
                        )
                        .join(", "),
                    params.columns().join(", "),
                    params.keys().join(", "),
                );

                Ok(Statement::new(sql, params))
            }
            DbRecord::ValueState(state) => {
                let mut params = SqlParams::new();
                params.add("raw_label", Box::new(state.get_id().0.clone()));
                // TODO: Fixup as conversions
                params.add("epoch", Box::new(state.epoch as i64));
                params.add("[version]", Box::new(state.version as i64));
                params.add("node_label_val", Box::new(state.label.label_val.to_vec()));
                params.add("node_label_len", Box::new(state.label.label_len as i64));
                params.add("[data]", Box::new(state.value.0.clone()));

                // Note: raw_label & epoch combined form the primary key, so these rows are always new
                let sql = format!(
                    r#"
                    INSERT INTO dbo.{TABLE_VALUES} ({})
                    VALUES ({});
                    "#,
                    params.columns().join(", "),
                    params.keys().join(", "),
                );
                Ok(Statement::new(sql, params))
            }
        }
    }
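
    // Unlike set_statement above, the batch statements read from the
    // session-scoped temp tables (see temp_table.rs) rather than bound
    // parameters, so they can be run with simple_query after a bulk insert.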
    fn set_batch_statement(storage_type: &StorageType) -> String {
        match storage_type {
            StorageType::Azks => format!(
                r#"
                MERGE INTO dbo.{TABLE_AZKS} AS target
                USING {} AS source
                ON target.[key] = source.[key]
                WHEN MATCHED THEN
                    UPDATE SET
                        target.[epoch] = source.[epoch],
                        target.[num_nodes] = source.[num_nodes]
                WHEN NOT MATCHED THEN
                    INSERT ([key], [epoch], [num_nodes])
                    VALUES (source.[key], source.[epoch], source.[num_nodes]);
                "#,
                TempTable::Azks.to_string()
            ),
            StorageType::TreeNode => format!(
                r#"
                MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target
                USING {} AS source
                ON target.label_len = source.label_len AND target.label_val = source.label_val
                WHEN MATCHED THEN
                    UPDATE SET
                        target.last_epoch = source.last_epoch,
                        target.least_descendant_ep = source.least_descendant_ep,
                        target.parent_label_len = source.parent_label_len,
                        target.parent_label_val = source.parent_label_val,
                        target.node_type = source.node_type,
                        target.left_child_len = source.left_child_len,
                        target.left_child_label_val = source.left_child_label_val,
                        target.right_child_len = source.right_child_len,
                        target.right_child_label_val = source.right_child_label_val,
                        target.hash = source.hash,
                        target.p_last_epoch = source.p_last_epoch,
                        target.p_least_descendant_ep = source.p_least_descendant_ep,
                        target.p_parent_label_len = source.p_parent_label_len,
                        target.p_parent_label_val = source.p_parent_label_val,
                        target.p_node_type = source.p_node_type,
                        target.p_left_child_len = source.p_left_child_len,
                        target.p_left_child_label_val = source.p_left_child_label_val,
                        target.p_right_child_len = source.p_right_child_len,
                        target.p_right_child_label_val = source.p_right_child_label_val,
                        target.p_hash = source.p_hash
                WHEN NOT MATCHED THEN
                    INSERT (
                        label_len
                        , label_val
                        , last_epoch
                        , least_descendant_ep
                        , parent_label_len
                        , parent_label_val
                        , node_type
                        , left_child_len
                        , left_child_label_val
                        , right_child_len
                        , right_child_label_val
                        , hash
                        , p_last_epoch
                        , p_least_descendant_ep
                        , p_parent_label_len
                        , p_parent_label_val
                        , p_node_type
                        , p_left_child_len
                        , p_left_child_label_val
                        , p_right_child_len
                        , p_right_child_label_val
                        , p_hash
                    )
                    VALUES (
                        source.label_len
                        , source.label_val
                        , source.last_epoch
                        , source.least_descendant_ep
                        , source.parent_label_len
                        , source.parent_label_val
                        , source.node_type
                        , source.left_child_len
                        , source.left_child_label_val
                        , source.right_child_len
                        , source.right_child_label_val
                        , source.hash
                        , source.p_last_epoch
                        , source.p_least_descendant_ep
                        , source.p_parent_label_len
                        , source.p_parent_label_val
                        , source.p_node_type
                        , source.p_left_child_len
                        , source.p_left_child_label_val
                        , source.p_right_child_len
                        , source.p_right_child_label_val
                        , source.p_hash
                    );
                "#,
                TempTable::HistoryTreeNodes.to_string()
            ),
            StorageType::ValueState => format!(
                r#"
                MERGE INTO dbo.{TABLE_VALUES} AS target
                USING {} AS source
                ON target.raw_label = source.raw_label AND target.epoch = source.epoch
                WHEN MATCHED THEN
                    UPDATE SET
                        target.[version] = source.[version],
                        target.node_label_val = source.node_label_val,
                        target.node_label_len = source.node_label_len,
                        target.[data] = source.[data]
                WHEN NOT MATCHED THEN
                    INSERT (raw_label, epoch, [version], node_label_val, node_label_len, [data])
                    VALUES (source.raw_label, source.epoch, source.[version], source.node_label_val, source.node_label_len, source.[data]);
                "#,
                TempTable::Values.to_string()
            ),
        }
    }

    fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError> {
        let mut params = SqlParams::new();
        let sql = match St::data_type() {
            StorageType::Azks => {
                format!(
                    r#"
                    SELECT TOP(1) {} FROM dbo.{};
                    "#,
                    SELECT_AZKS_DATA.join(", "),
                    TABLE_AZKS
                )
            }
            StorageType::TreeNode => {
                let bin = St::get_full_binary_key_id(key);
                // These are constructed from a safe key, they should never fail
                let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
                    .expect("Failed to decode key"); // TODO: should this be an error?

                params.add("label_len", Box::new(key.0.label_len as i32));
                params.add("label_val", Box::new(key.0.label_val.to_vec()));

                format!(
                    r#"
                    SELECT {} FROM dbo.{TABLE_HISTORY_TREE_NODES} WHERE [label_len] = {} AND [label_val] = {};
                    "#,
                    SELECT_HISTORY_TREE_NODE_DATA.join(", "),
                    params.key_for("label_len").expect("key present"),
                    params.key_for("label_val").expect("key present"),
                )
            }
            StorageType::ValueState => {
                let bin = St::get_full_binary_key_id(key);
                // These are constructed from a safe key, they should never fail
                let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?

                params.add("raw_label", Box::new(key.0.clone()));
                params.add("epoch", Box::new(key.1 as i64));
                format!(
                    r#"
                    SELECT {} FROM dbo.{TABLE_VALUES} WHERE [raw_label] = {} AND [epoch] = {};
                    "#,
                    SELECT_LABEL_DATA.join(", "),
                    params.key_for("raw_label").expect("key present"),
                    params.key_for("epoch").expect("key present"),
                )
            }
        };

        Ok(Statement::new(sql, params))
    }

    fn get_batch_temp_table_rows<St: Storable>(
        key: &[St::StorageKey],
    ) -> Result<Vec<TokenRow>, StorageError> {
        match St::data_type() {
            StorageType::Azks => Err(StorageError::Other(
                "Batch temp table rows not supported for Azks".to_string(),
            )),
            StorageType::TreeNode => {
                let mut rows = Vec::new();
                for k in key {
                    let bin = St::get_full_binary_key_id(k);
                    // These are constructed from a safe key, they should never fail
                    let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
                        .expect("Failed to decode key");

                    let row = (key.0.label_len as i32, key.0.label_val.to_vec()).into_row();
                    rows.push(row);
                }
                Ok(rows)
            }
            StorageType::ValueState => {
                let mut rows = Vec::new();
                for k in key {
                    let bin = St::get_full_binary_key_id(k);
                    // These are constructed from a safe key, they should never fail
                    let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?

                    let row = (key.0.clone(), key.1 as i64).into_row();
                    rows.push(row);
                }
                Ok(rows)
            }
        }
    }

    fn get_batch_statement<St: Storable>() -> String {
        // Note: any changes to these columns need to be reflected in from_row below
        match St::data_type() {
            StorageType::Azks => panic!("Batch get not supported for Azks"),
            StorageType::TreeNode => format!(
                r#"
                SELECT {}
                FROM dbo.{TABLE_HISTORY_TREE_NODES} h
                INNER JOIN {} t
                    ON h.label_len = t.label_len
                    AND h.label_val = t.label_val;
                "#,
                SELECT_HISTORY_TREE_NODE_DATA
                    .iter()
                    .map(|s| format!("h.{s}"))
                    .collect::<Vec<_>>()
                    .join(", "),
                TempTable::for_ids::<St>().to_string()
            ),
            StorageType::ValueState => format!(
                r#"
                SELECT {}
                FROM dbo.{TABLE_VALUES} v
                INNER JOIN {} t
                    ON v.raw_label = t.raw_label
                    AND v.epoch = t.epoch;
                "#,
                SELECT_LABEL_DATA
                    .iter()
                    .map(|s| format!("v.{s}"))
                    .collect::<Vec<_>>()
                    .join(", "),
                TempTable::for_ids::<St>().to_string()
            ),
        }
    }

    fn from_row<St: Storable>(row: &Row) -> Result<Self, StorageError>
    where
        Self: Sized,
    {
        match St::data_type() {
            // TODO: check this
            StorageType::Azks => {
                let epoch: i64 = row
                    .get("epoch")
                    .ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?;
                let num_nodes: i64 = row.get("num_nodes").ok_or_else(|| {
                    StorageError::Other("num_nodes is NULL or missing".to_string())
                })?;

                let azks = DbRecord::build_azks(epoch as u64, num_nodes as u64);
                Ok(DbRecord::Azks(azks))
            }
            StorageType::TreeNode => {
                let label_len: i32 = row.get("label_len").ok_or_else(|| {
                    StorageError::Other("label_len is NULL or missing".to_string())
                })?;
                let label_val: &[u8] = row.get("label_val").ok_or_else(|| {
                    StorageError::Other("label_val is NULL or missing".to_string())
                })?;
                let last_epoch: i64 = row.get("last_epoch").ok_or_else(|| {
                    StorageError::Other("last_epoch is NULL or missing".to_string())
                })?;
                let least_descendant_ep: i64 = row.get("least_descendant_ep").ok_or_else(|| {
                    StorageError::Other("least_descendant_ep is NULL or missing".to_string())
                })?;
                let parent_label_len: i32 = row.get("parent_label_len").ok_or_else(|| {
                    StorageError::Other("parent_label_len is NULL or missing".to_string())
                })?;
                let parent_label_val: &[u8] = row.get("parent_label_val").ok_or_else(|| {
                    StorageError::Other("parent_label_val is NULL or missing".to_string())
                })?;
                let node_type: i16 = row.get("node_type").ok_or_else(|| {
                    StorageError::Other("node_type is NULL or missing".to_string())
                })?;
                let left_child_len: Option<i32> = row.get("left_child_len");
                let left_child_label_val: Option<&[u8]> = row.get("left_child_label_val");
                let right_child_len: Option<i32> = row.get("right_child_len");
                let right_child_label_val: Option<&[u8]> = row.get("right_child_label_val");
                let hash: &[u8] = row
                    .get("hash")
                    .ok_or_else(|| StorageError::Other("hash is NULL or missing".to_string()))?;
                let p_last_epoch: Option<i64> = row.get("p_last_epoch");
                let p_least_descendant_ep: Option<i64> = row.get("p_least_descendant_ep");
                let p_parent_label_len: Option<i32> = row.get("p_parent_label_len");
                let p_parent_label_val: Option<&[u8]> = row.get("p_parent_label_val");
                let p_node_type: Option<i16> = row.get("p_node_type");
                let p_left_child_len: Option<i32> = row.get("p_left_child_len");
                let p_left_child_label_val: Option<&[u8]> = row.get("p_left_child_label_val");
                let p_right_child_len: Option<i32> = row.get("p_right_child_len");
                let p_right_child_label_val: Option<&[u8]> = row.get("p_right_child_label_val");
                let p_hash: Option<&[u8]> = row.get("p_hash");

                // Make child nodes
                fn optional_child_label(
                    child_val: Option<&[u8]>,
                    child_len: Option<i32>,
                ) -> Result<Option<NodeLabel>, StorageError> {
                    match (child_val, child_len) {
                        (Some(val), Some(len)) => {
                            let val_arr = val.to_vec().try_into().map_err(|_| {
                                StorageError::Other("child_val has incorrect length".to_string())
                            })?;
                            Ok(Some(NodeLabel::new(val_arr, len as u32)))
                        }
                        _ => Ok(None),
                    }
                }
                let left_child = optional_child_label(left_child_label_val, left_child_len)?;
                let right_child = optional_child_label(right_child_label_val, right_child_len)?;
                let p_left_child = optional_child_label(p_left_child_label_val, p_left_child_len)?;
                let p_right_child =
                    optional_child_label(p_right_child_label_val, p_right_child_len)?;

                let massaged_p_parent_label_val: Option<[u8; 32]> = match p_parent_label_val {
                    Some(v) => Some(v.to_vec().try_into().map_err(|_| {
                        StorageError::Other("p_parent_label_val has incorrect length".to_string())
                    })?),
                    None => None,
                };
                let massaged_hash: akd::Digest = akd::hash::try_parse_digest(&hash.to_vec())
                    .map_err(|_| StorageError::Other("hash has incorrect length".to_string()))?;
                let massaged_p_hash: Option<akd::Digest> = match p_hash {
                    Some(v) => Some(akd::hash::try_parse_digest(&v.to_vec()).map_err(|_| {
                        StorageError::Other("p_hash has incorrect length".to_string())
                    })?),
                    None => None,
                };

                let node = DbRecord::build_tree_node_with_previous_value(
                    label_val.try_into().map_err(|_| {
                        StorageError::Other("label_val has incorrect length".to_string())
                    })?,
                    label_len as u32,
                    last_epoch as u64,
                    least_descendant_ep as u64,
                    parent_label_val.try_into().map_err(|_| {
                        StorageError::Other("parent_label_val has incorrect length".to_string())
                    })?,
                    parent_label_len as u32,
                    node_type as u8,
                    left_child,
                    right_child,
                    massaged_hash,
                    p_last_epoch.map(|v| v as u64),
                    p_least_descendant_ep.map(|v| v as u64),
                    massaged_p_parent_label_val,
                    p_parent_label_len.map(|v| v as u32),
                    p_node_type.map(|v| v as u8),
                    p_left_child,
                    p_right_child,
                    massaged_p_hash,
                );

                Ok(DbRecord::TreeNode(node))
            }
            StorageType::ValueState => Ok(DbRecord::ValueState(crate::tables::values::from_row(row)?)),
        }
    }
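
    // Note: the ColumnData push order in into_row below must match the column
    // order of the matching temp table, since bulk_insert maps values to
    // columns positionally rather than by name.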
    fn into_row(&self) -> Result<TokenRow, StorageError> {
        match &self {
            DbRecord::Azks(azks) => {
                let row = (
                    1u8, // constant key
                    azks.latest_epoch as i64,
                    azks.num_nodes as i64,
                )
                    .into_row();
                Ok(row)
            }
            DbRecord::TreeNode(node) => {
                let mut row = TokenRow::new();
                row.push(ColumnData::I32(Some(node.label.label_len as i32)));
                row.push(ColumnData::Binary(Some(
                    node.label.label_val.to_vec().into(),
                )));
                // Latest node values
                row.push(ColumnData::I64(Some(node.latest_node.last_epoch as i64)));
                row.push(ColumnData::I64(Some(
                    node.latest_node.min_descendant_epoch as i64,
                )));
                row.push(ColumnData::I32(Some(
                    node.latest_node.parent.label_len as i32,
                )));
                row.push(ColumnData::Binary(Some(
                    node.latest_node.parent.label_val.to_vec().into(),
                )));
                row.push(ColumnData::I16(Some(node.latest_node.node_type as i16)));
                match &node.latest_node.left_child {
                    Some(lc) => {
                        row.push(ColumnData::I32(Some(lc.label_len as i32)));
                        row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
                    }
                    None => {
                        row.push(ColumnData::I32(None));
                        row.push(ColumnData::Binary(None));
                    }
                }
                match &node.latest_node.right_child {
                    Some(rc) => {
                        row.push(ColumnData::I32(Some(rc.label_len as i32)));
                        row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
                    }
                    None => {
                        row.push(ColumnData::I32(None));
                        row.push(ColumnData::Binary(None));
                    }
                }
                row.push(ColumnData::Binary(Some(
                    node.latest_node.hash.0.to_vec().into(),
                )));
                // Previous node values
                match &node.previous_node {
                    Some(p) => {
                        row.push(ColumnData::I64(Some(p.last_epoch as i64)));
                        row.push(ColumnData::I64(Some(p.min_descendant_epoch as i64)));
                        row.push(ColumnData::I32(Some(p.label.label_len as i32)));
                        row.push(ColumnData::Binary(Some(p.label.label_val.to_vec().into())));
                        row.push(ColumnData::I16(Some(p.node_type as i16)));
                        match &p.left_child {
                            Some(lc) => {
                                row.push(ColumnData::I32(Some(lc.label_len as i32)));
                                row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
                            }
                            None => {
                                row.push(ColumnData::I32(None));
                                row.push(ColumnData::Binary(None));
                            }
                        }
                        match &p.right_child {
                            Some(rc) => {
                                row.push(ColumnData::I32(Some(rc.label_len as i32)));
                                row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
                            }
                            None => {
                                row.push(ColumnData::I32(None));
                                row.push(ColumnData::Binary(None));
                            }
                        }
                        row.push(ColumnData::Binary(Some(p.hash.0.to_vec().into())));
                    }
                    None => {
                        // Node Values
                        row.push(ColumnData::I64(None));
                        row.push(ColumnData::I64(None));
                        row.push(ColumnData::I32(None));
                        row.push(ColumnData::Binary(None));
                        row.push(ColumnData::I16(None));
                        // Left child
                        row.push(ColumnData::I32(None));
                        row.push(ColumnData::Binary(None));
                        // Right child
                        row.push(ColumnData::I32(None));
                        row.push(ColumnData::Binary(None));
                        // Hash
                        row.push(ColumnData::Binary(None));
                    }
                }
                Ok(row)
            }
            DbRecord::ValueState(state) => {
                let row = (
                    state.get_id().0.clone(),
                    state.epoch as i64,
                    state.version as i64,
                    state.label.label_val.to_vec(),
                    state.label.label_len as i64,
                    state.value.0.clone(),
                )
                    .into_row();
                Ok(row)
            }
        }
    }
}

akd/crates/akd_storage/src/sql_params.rs (new file, 177 lines)
@@ -0,0 +1,177 @@
use ms_database::ToSql;

pub struct SqlParam {
    /// The parameter key (e.g., "@P1", "@P2")
    pub key: String,
    /// The column name this parameter maps to
    column: String,
    pub data: Box<dyn ToSql>,
}

impl SqlParam {
    fn column(&self) -> String {
        SqlParam::wrap_in_brackets(&self.column)
    }

    fn wrap_in_brackets(s: &str) -> String {
        // Ensure column names are wrapped in brackets for SQL Server
        let trimmed = s.trim();
        let starts_with_bracket = trimmed.starts_with('[');
        let ends_with_bracket = trimmed.ends_with(']');

        match (starts_with_bracket, ends_with_bracket) {
            (true, true) => trimmed.to_string(),
            (true, false) => format!("{}]", trimmed),
            (false, true) => format!("[{}", trimmed),
            (false, false) => format!("[{}]", trimmed),
        }
    }
}

pub(crate) struct SqlParams {
    params: Vec<Box<SqlParam>>,
}

impl SqlParams {
    pub fn new() -> Self {
        Self { params: Vec::new() }
    }

    pub fn add(&mut self, column: impl Into<String>, value: Box<dyn ToSql>) {
        self.params.push(Box::new(SqlParam {
            key: format!("@P{}", self.params.len() + 1),
            column: column.into(),
            data: value,
        }));
    }

    pub fn keys(&self) -> Vec<String> {
        self.params.iter().map(|p| p.key.clone()).collect()
    }

    pub fn keys_as_columns(&self) -> Vec<String> {
        self.params
            .iter()
            .map(|p| format!("{} AS {}", p.key, p.column()))
            .collect()
    }

    pub fn keys_except_columns(&self, excludes: Vec<&str>) -> Vec<String> {
        self.params
            .iter()
            .filter(|p| !excludes.contains(&p.column.as_str()))
            .map(|p| p.key.clone())
            .collect()
    }

    pub fn key_for(&self, column: &str) -> Option<String> {
        self.params
            .iter()
            .find(|p| p.column == column)
            .map(|p| p.key.clone())
    }

    pub fn columns(&self) -> Vec<String> {
        self.params
            .iter()
            .map(|p| p.column())
            .collect()
    }

    pub fn columns_except(&self, excludes: Vec<&str>) -> Vec<String> {
        self.params
            .iter()
            .filter(|p| !excludes.contains(&p.column.as_str()))
            .map(|p| p.column())
            .collect()
    }

    pub fn columns_prefix_with(&self, prefix: &str) -> Vec<String> {
        self.params
            .iter()
            .map(|p| format!("{}{}", prefix, p.column()))
            .collect()
    }

    pub fn set_columns_equal(&self, assign_prefix: &str, source_prefix: &str) -> Vec<String> {
        self.params
            .iter()
            .map(|p| format!("{}{} = {}{}", assign_prefix, p.column(), source_prefix, p.column()))
            .collect()
    }

    pub fn set_columns_equal_except(
        &self,
        assign_prefix: &str,
        source_prefix: &str,
        excludes: Vec<&str>,
    ) -> Vec<String> {
        self.params
            .iter()
            .filter(|p| !excludes.contains(&p.column.as_str()))
            .map(|p| format!("{}{} = {}{}", assign_prefix, p.column(), source_prefix, p.column()))
            .collect()
    }

    pub fn values(&self) -> Vec<&(dyn ToSql)> {
        self.params
            .iter()
            .map(|b| b.data.as_ref() as &(dyn ToSql))
            .collect()
    }
}
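
// Example: after add("epoch", ...) and add("num_nodes", ...), keys() yields
// ["@P1", "@P2"] and keys_as_columns() renders "@P1 AS [epoch], @P2 AS
// [num_nodes]" when joined, which is how the MERGE source SELECTs are built.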

pub struct VecStringBuilder<'a> {
    strings: Vec<String>,
    ops: Vec<StringBuilderOperation<'a>>,
}

enum StringBuilderOperation<'a> {
    StringOperation(Box<dyn Fn(String) -> String + 'a>),
    VectorOperation(Box<dyn Fn(Vec<String>) -> Vec<String> + 'a>),
}

impl<'a> VecStringBuilder<'a> {
    pub fn new(strings: Vec<String>) -> Self {
        Self {
            strings,
            ops: Vec::new(),
        }
    }

    pub fn map<F>(&mut self, op: F)
    where
        F: Fn(String) -> String + 'static,
    {
        self.ops.push(StringBuilderOperation::StringOperation(Box::new(op)));
    }

    pub fn build(self) -> Vec<String> {
        let mut result = self.strings;
        for op in self.ops {
            match op {
                StringBuilderOperation::StringOperation(f) => {
                    result = result.into_iter().map(f.as_ref()).collect();
                }
                StringBuilderOperation::VectorOperation(f) => {
                    result = f(result);
                }
            }
        }
        result
    }

    pub fn join(self, sep: &str) -> String {
        self.build().join(sep)
    }

    pub fn except(&mut self, excludes: Vec<&'a str>) {
        self.ops.push(StringBuilderOperation::VectorOperation(Box::new(
            move |vec: Vec<String>| {
                vec.into_iter()
                    .filter(|s| !excludes.contains(&s.as_str()))
                    .collect()
            },
        )));
    }
}

akd/crates/akd_storage/src/tables/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub(crate) mod values;

akd/crates/akd_storage/src/tables/values.rs (new file, 190 lines)
@@ -0,0 +1,190 @@
use akd::{
    errors::StorageError,
    storage::types::{DbRecord, ValueState, ValueStateRetrievalFlag},
    AkdLabel, AkdValue,
};
use ms_database::Row;

use crate::{
    migrations::TABLE_VALUES,
    ms_sql_storable::{QueryStatement, Statement},
    sql_params::SqlParams,
};

pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
    let mut params = SqlParams::new();
    // The raw vector is the key for value storage
    params.add("raw_label", Box::new(raw_label.0.clone()));

    let sql = format!(
        r#"
        SELECT raw_label, epoch, version, node_label_val, node_label_len, data
        FROM {}
        WHERE raw_label = {}
        "#,
        TABLE_VALUES,
        params
            .key_for("raw_label")
            .expect("raw_label was added to the params list")
    );
    QueryStatement::new(sql, params, from_row)
}

pub fn get_by_flag(raw_label: &AkdLabel, flag: ValueStateRetrievalFlag) -> QueryStatement<ValueState> {
    let mut params = SqlParams::new();
    params.add("raw_label", Box::new(raw_label.0.clone()));

    match flag {
        ValueStateRetrievalFlag::SpecificEpoch(epoch)
        | ValueStateRetrievalFlag::LeqEpoch(epoch) => params.add("epoch", Box::new(epoch as i64)),
        ValueStateRetrievalFlag::SpecificVersion(version) => {
            params.add("version", Box::new(version as i64))
        }
        _ => {}
    }

    let mut sql = format!(
        r#"
        SELECT TOP(1) raw_label, epoch, version, node_label_val, node_label_len, data
        FROM {}
        WHERE raw_label = {}
        "#,
        TABLE_VALUES,
        params
            .key_for("raw_label")
            .expect("raw_label was added to the params list")
    );

    match flag {
        ValueStateRetrievalFlag::SpecificEpoch(_) => {
            sql.push_str(&format!(
                " AND epoch = {}",
                params
                    .key_for("epoch")
                    .expect("epoch was added to the params list")
            ));
        }
        ValueStateRetrievalFlag::SpecificVersion(_) => {
            sql.push_str(&format!(
                " AND version = {}",
                params
                    .key_for("version")
                    .expect("version was added to the params list")
            ));
        }
        ValueStateRetrievalFlag::MaxEpoch => {
            sql.push_str(" ORDER BY epoch DESC ");
        }
        ValueStateRetrievalFlag::MinEpoch => {
            sql.push_str(" ORDER BY epoch ASC ");
        }
        ValueStateRetrievalFlag::LeqEpoch(_) => {
            sql.push_str(&format!(
                r#"
                AND epoch <= {}
                ORDER BY epoch DESC
                "#,
                params
                    .key_for("epoch")
                    .expect("epoch was added to the params list")
            ));
        }
    }

    QueryStatement::new(sql, params, from_row)
}
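
// The query built below is a greatest-n-per-group lookup: the inner SELECT
// picks one epoch per raw_label according to the retrieval flag, and the
// outer join then pulls the full row for that (raw_label, epoch) pair.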
pub fn get_versions_by_flag(
    temp_table_name: &str,
    flag: ValueStateRetrievalFlag,
) -> QueryStatement<LabelVersion> {
    let mut params = SqlParams::new();

    let (filter, epoch_col) = match flag {
        ValueStateRetrievalFlag::SpecificVersion(version) => {
            params.add("version", Box::new(version as i64));
            (
                format!(
                    "WHERE tmp.version = {}",
                    params.key_for("version").expect("version was added to the params list")
                ),
                // MAX keeps the selected epoch valid under the GROUP BY below
                "MAX(tmp.epoch)",
            )
        }
        ValueStateRetrievalFlag::SpecificEpoch(epoch) => {
            params.add("epoch", Box::new(epoch as i64));
            (
                format!(
                    "WHERE tmp.epoch = {}",
                    params.key_for("epoch").expect("epoch was added to the params list")
                ),
                "MAX(tmp.epoch)",
            )
        }
        ValueStateRetrievalFlag::LeqEpoch(epoch) => {
            params.add("epoch", Box::new(epoch as i64));
            (
                format!(
                    "WHERE tmp.epoch <= {}",
                    params.key_for("epoch").expect("epoch was added to the params list")
                ),
                "MAX(tmp.epoch)",
            )
        }
        ValueStateRetrievalFlag::MaxEpoch => ("".to_string(), "MAX(tmp.epoch)"),
        ValueStateRetrievalFlag::MinEpoch => ("".to_string(), "MIN(tmp.epoch)"),
    };

    let sql = format!(
        r#"
        SELECT t.raw_label, t.version, t.data
        FROM {TABLE_VALUES} t
        INNER JOIN (
            SELECT tmp.raw_label AS raw_label, {} AS epoch
            FROM {TABLE_VALUES} tmp
            INNER JOIN {temp_table_name} s ON s.raw_label = tmp.raw_label
            {}
            GROUP BY tmp.raw_label
        ) epochs ON epochs.raw_label = t.raw_label AND epochs.epoch = t.epoch
        "#,
        epoch_col,
        filter,
    );

    QueryStatement::new(sql, params, version_from_row)
}

pub(crate) fn from_row(row: &Row) -> Result<ValueState, StorageError> {
    let raw_label: &[u8] = row
        .get("raw_label")
        .ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?;
    let epoch: i64 = row
        .get("epoch")
        .ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?;
    let version: i64 = row
        .get("version")
        .ok_or_else(|| StorageError::Other("version is NULL or missing".to_string()))?;
    let node_label_val: &[u8] = row
        .get("node_label_val")
        .ok_or_else(|| StorageError::Other("node_label_val is NULL or missing".to_string()))?;
    let node_label_len: i64 = row
        .get("node_label_len")
        .ok_or_else(|| StorageError::Other("node_label_len is NULL or missing".to_string()))?;
    let data: &[u8] = row
        .get("data")
        .ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?;

    let state = DbRecord::build_user_state(
        raw_label.to_vec(),
        data.to_vec(),
        version as u64,
        node_label_len as u32,
        node_label_val
            .to_vec()
            .try_into()
            .map_err(|_| StorageError::Other("node_label_val has incorrect length".to_string()))?,
        epoch as u64,
    );
    Ok(state)
}

pub(crate) struct LabelVersion {
    pub label: AkdLabel,
    pub version: u64,
    pub data: AkdValue,
}

fn version_from_row(row: &Row) -> Result<LabelVersion, StorageError> {
    let raw_label: &[u8] = row
        .get("raw_label")
        .ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?;
    let version: i64 = row
        .get("version")
        .ok_or_else(|| StorageError::Other("version is NULL or missing".to_string()))?;
    let data: &[u8] = row
        .get("data")
        .ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?;

    Ok(LabelVersion {
        label: AkdLabel(raw_label.to_vec()),
        version: version as u64,
        data: AkdValue(data.to_vec()),
    })
}

@@ -7,13 +7,10 @@ pub(crate) enum TempTable {
    Azks,
    HistoryTreeNodes,
    Values,
    RawLabelSearch,
}

impl TempTable {
    pub fn for_ids<St: Storable>() -> Self {
        TempTable::Ids(St::data_type())
    }

@@ -111,6 +108,15 @@
                    TEMP_IDS_TABLE
                ),
            }
            TempTable::RawLabelSearch => format!(
                r#"
                CREATE TABLE {} (
                    raw_label VARBINARY(256) NOT NULL,
                    PRIMARY KEY (raw_label)
                );
                "#,
                TEMP_SEARCH_LABELS_TABLE
            )
        }
    }

@@ -122,6 +128,7 @@ impl ToString for TempTable {
            TempTable::Azks => TEMP_AZKS_TABLE.to_string(),
            TempTable::HistoryTreeNodes => TEMP_HISTORY_TREE_NODES_TABLE.to_string(),
            TempTable::Values => TEMP_VALUES_TABLE.to_string(),
            TempTable::RawLabelSearch => TEMP_SEARCH_LABELS_TABLE.to_string(),
        }
    }
}

@@ -137,6 +144,7 @@ impl From<StorageType> for TempTable {
}

pub(crate) const TEMP_IDS_TABLE: &str = "#akd_temp_ids";
pub(crate) const TEMP_SEARCH_LABELS_TABLE: &str = "#akd_temp_search_labels";
pub(crate) const TEMP_AZKS_TABLE: &str = "#akd_temp_azks";
pub(crate) const TEMP_HISTORY_TREE_NODES_TABLE: &str = "#akd_temp_history_tree_nodes";
pub(crate) const TEMP_VALUES_TABLE: &str = "#akd_temp_values";

@@ -6,7 +6,7 @@ pub use pool::ConnectionManager as MsSqlConnectionManager;
pub use pool::{OnConnectError};

// re-expose tiberius types for convenience
pub use tiberius::{error::Error as MsDbError, Column, Row, FromSql, ToSql, IntoRow, TokenRow, ColumnData};

// re-expose bb8 types for convenience
pub type Pool = bb8::Pool<MsSqlConnectionManager>;