From bbd1a230a6a96337329c8a4b5986e2d879ec8fe7 Mon Sep 17 00:00:00 2001
From: Matt Gibson
Date: Tue, 21 Oct 2025 13:16:05 -0700
Subject: [PATCH] First complete implementation of `Database` trait for SQL
 Server

---
 akd/crates/akd_storage/src/lib.rs             |   1 +
 akd/crates/akd_storage/src/ms_sql.rs          | 469 ++++++++++-
 akd/crates/akd_storage/src/ms_sql_storable.rs | 781 ++++++++++++++++++
 akd/crates/akd_storage/src/sql_params.rs      | 177 ++++
 akd/crates/akd_storage/src/tables/mod.rs      |   1 +
 akd/crates/akd_storage/src/tables/values.rs   | 190 +++++
 akd/crates/akd_storage/src/temp_table.rs      |  16 +-
 akd/crates/ms_database/src/lib.rs             |   2 +-
 8 files changed, 1616 insertions(+), 21 deletions(-)
 create mode 100644 akd/crates/akd_storage/src/ms_sql_storable.rs
 create mode 100644 akd/crates/akd_storage/src/sql_params.rs
 create mode 100644 akd/crates/akd_storage/src/tables/mod.rs
 create mode 100644 akd/crates/akd_storage/src/tables/values.rs

diff --git a/akd/crates/akd_storage/src/lib.rs b/akd/crates/akd_storage/src/lib.rs
index fa23c7505a..6d33dfea5e 100644
--- a/akd/crates/akd_storage/src/lib.rs
+++ b/akd/crates/akd_storage/src/lib.rs
@@ -2,4 +2,5 @@ mod migrations;
 mod sql_params;
 mod ms_sql_storable;
 pub mod ms_sql;
+mod tables;
 mod temp_table;

diff --git a/akd/crates/akd_storage/src/ms_sql.rs b/akd/crates/akd_storage/src/ms_sql.rs
index 419d57f04b..9260ceafad 100644
--- a/akd/crates/akd_storage/src/ms_sql.rs
+++ b/akd/crates/akd_storage/src/ms_sql.rs
@@ -1,48 +1,485 @@
-use std::collections::HashMap;
+use std::{cmp::Ordering, collections::HashMap, sync::Arc};

-use akd::{errors::StorageError, storage::{types::{self, DbRecord}, Database, DbSetState, Storable}, AkdLabel, AkdValue};
+use akd::{
+    errors::StorageError,
+    storage::{
+        types::{self, DbRecord, KeyData, StorageType},
+        Database, DbSetState, Storable,
+    },
+    AkdLabel, AkdValue,
+};
 use async_trait::async_trait;
-use ms_database::ConnectionManager;
+use ms_database::{IntoRow, MsSqlConnectionManager, Pool, PooledConnection};
+
+use crate::{
+    migrations::MIGRATIONS,
+    ms_sql_storable::{MsSqlStorable, Statement},
+    tables::values,
+    temp_table::TempTable,
+};
+
+const DEFAULT_POOL_SIZE: u32 = 100;
+
+pub struct MsSqlBuilder {
+    connection_string: String,
+    pool_size: Option<u32>,
+}
+
+impl MsSqlBuilder {
+    /// Create a new [MsSqlBuilder] with the given connection string.
+    pub fn new(connection_string: String) -> Self {
+        Self {
+            connection_string,
+            pool_size: None,
+        }
+    }
+
+    /// Set the connection pool size. The default is [DEFAULT_POOL_SIZE].
+    pub fn pool_size(mut self, pool_size: u32) -> Self {
+        self.pool_size = Some(pool_size);
+        self
+    }
+
+    /// Build the [MsSql] instance.
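+    ///
+    /// A minimal usage sketch; the connection string and pool size here are
+    /// illustrative only:
+    ///
+    /// ```ignore
+    /// let db = MsSql::builder("server=tcp:localhost,1433;...".to_string())
+    ///     .pool_size(32)
+    ///     .build()
+    ///     .await?;
+    /// db.migrate().await?;
+    /// ```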
+    pub async fn build(self) -> Result<MsSql, StorageError> {
+        let pool_size = self.pool_size.unwrap_or(DEFAULT_POOL_SIZE);
+
+        MsSql::new(self.connection_string, pool_size).await
+    }
+}

 pub struct MsSql {
-    connection_manager: ConnectionManager,
+    pool: Arc<Pool>,
+}
+
+impl MsSql {
+    pub fn builder(connection_string: String) -> MsSqlBuilder {
+        MsSqlBuilder::new(connection_string)
+    }
+
+    pub async fn new(connection_string: String, pool_size: u32) -> Result<Self, StorageError> {
+        let connection_manager = MsSqlConnectionManager::new(connection_string);
+        let pool = Pool::builder()
+            .max_size(pool_size)
+            .build(connection_manager)
+            .await
+            .map_err(|e| StorageError::Connection(format!("Failed to create DB pool: {}", e)))?;
+
+        Ok(Self {
+            pool: Arc::new(pool),
+        })
+    }
+
+    pub async fn migrate(&self) -> Result<(), StorageError> {
+        let mut conn = self.pool.get().await.map_err(|e| {
+            StorageError::Connection(format!("Failed to get DB connection for migrations: {}", e))
+        })?;
+
+        ms_database::run_pending_migrations(&mut conn, MIGRATIONS)
+            .await
+            .map_err(|e| StorageError::Connection(format!("Failed to run migrations: {}", e)))?;
+        Ok(())
+    }
+
+    async fn get_connection(&self) -> Result<PooledConnection<'_>, StorageError> {
+        self.pool
+            .get()
+            .await
+            .map_err(|e| StorageError::Connection(format!("Failed to get DB connection: {}", e)))
+    }
+
+    async fn execute_statement(&self, statement: &Statement) -> Result<(), StorageError> {
+        let mut conn = self.get_connection().await?;
+        self.execute_statement_on_connection(statement, &mut conn)
+            .await
+    }
+
+    async fn execute_statement_on_connection(
+        &self,
+        statement: &Statement,
+        conn: &mut PooledConnection<'_>,
+    ) -> Result<(), StorageError> {
+        conn.execute(statement.sql(), &statement.params())
+            .await
+            .map_err(|e| StorageError::Other(format!("Failed to execute statement: {}", e)))?;
+        Ok(())
+    }
 }

 #[async_trait]
 impl Database for MsSql {
     async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
-        todo!()
+        let statement = record.set_statement()?;
+        self.execute_statement(&statement).await
     }

-    async fn batch_set(&self, records: Vec<DbRecord>, state: DbSetState) -> Result<(), StorageError> {
-        todo!()
+    async fn batch_set(
+        &self,
+        records: Vec<DbRecord>,
+        _state: DbSetState, // TODO: unused in mysql example, but may be needed later
+    ) -> Result<(), StorageError> {
+        // Generate groups by type
+        let mut groups = HashMap::new();
+        for record in records {
+            match &record {
+                DbRecord::Azks(_) => groups
+                    .entry(StorageType::Azks)
+                    .or_insert_with(Vec::new)
+                    .push(record),
+                DbRecord::TreeNode(_) => groups
+                    .entry(StorageType::TreeNode)
+                    .or_insert_with(Vec::new)
+                    .push(record),
+                DbRecord::ValueState(_) => groups
+                    .entry(StorageType::ValueState)
+                    .or_insert_with(Vec::new)
+                    .push(record),
+            }
+        }
+
+        // Execute each group in batches
+        let mut conn = self.get_connection().await?;
+        // Start transaction
+        conn.simple_query("BEGIN TRANSACTION")
+            .await
+            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
+        let result = async {
+            for (storage_type, mut record_group) in groups.into_iter() {
+                if record_group.is_empty() {
+                    continue;
+                }
+
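+                // Deterministic per-batch write order keeps lock acquisition
+                // order stable across concurrent writers (an assumption carried
+                // over from the db-layer sorting this mirrors).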
+                // Sort the records to match db-layer sorting
+                record_group.sort_by(|a, b| match &a {
+                    DbRecord::TreeNode(node) => {
+                        if let DbRecord::TreeNode(node2) = &b {
+                            node.label.cmp(&node2.label)
+                        } else {
+                            Ordering::Equal
+                        }
+                    }
+                    DbRecord::ValueState(state) => {
+                        if let DbRecord::ValueState(state2) = &b {
+                            match state.username.0.cmp(&state2.username.0) {
+                                Ordering::Equal => state.epoch.cmp(&state2.epoch),
+                                other => other,
+                            }
+                        } else {
+                            Ordering::Equal
+                        }
+                    }
+                    _ => Ordering::Equal,
+                });
+
+                // Execute the group as a bulk insert
+
+                // Create temp table
+                let table: TempTable = storage_type.into();
+                conn.simple_query(&table.create()).await.map_err(|e| {
+                    StorageError::Other(format!("Failed to create temp table: {e}"))
+                })?;
+
+                // Create bulk insert
+                let table_name = &table.to_string();
+                let mut bulk = conn.bulk_insert(table_name).await.map_err(|e| {
+                    StorageError::Other(format!("Failed to start bulk insert: {e}"))
+                })?;
+
+                for record in &record_group {
+                    let row = record.into_row()?;
+                    bulk.send(row).await.map_err(|e| {
+                        StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
+                    })?;
+                }
+
+                bulk.finalize().await.map_err(|e| {
+                    StorageError::Other(format!("Failed to finalize bulk insert: {e}"))
+                })?;
+
+                // Merge values from the temp table into the main table
+                let sql = <DbRecord as MsSqlStorable>::set_batch_statement(&storage_type);
+                conn.simple_query(&sql).await.map_err(|e| {
+                    StorageError::Other(format!("Failed to execute batch set statement: {e}"))
+                })?;
+
+                // Delete the temp table
+                conn.simple_query(&table.drop())
+                    .await
+                    .map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;
+            }
+
+            Ok::<(), StorageError>(())
+        };
+
+        match result.await {
+            Ok(_) => {
+                conn.simple_query("COMMIT").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
+                })?;
+                Ok(())
+            }
+            Err(e) => {
+                conn.simple_query("ROLLBACK").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to roll back transaction: {e}"))
+                })?;
+                Err(StorageError::Other(format!(
+                    "Failed to batch set records: {}",
+                    e
+                )))
+            }
+        }
     }

     async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
-        todo!()
+        let mut conn = self.get_connection().await?;
+        let statement = DbRecord::get_statement::<St>(id)?;
+
+        let query_stream = conn
+            .query(statement.sql(), &statement.params())
+            .await
+            .map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;
+
+        let row = query_stream
+            .into_row()
+            .await
+            .map_err(|e| StorageError::Other(format!("Failed to fetch row: {e}")))?;
+
+        if let Some(row) = row {
+            DbRecord::from_row::<St>(&row)
+        } else {
+            Err(StorageError::NotFound(format!(
+                "{:?} {:?} not found",
+                St::data_type(),
+                id
+            )))
+        }
     }

-    async fn batch_get<St: Storable>(&self, ids: &[St::StorageKey]) -> Result<Vec<DbRecord>, StorageError> {
-        todo!()
+    async fn batch_get<St: Storable>(
+        &self,
+        ids: &[St::StorageKey],
+    ) -> Result<Vec<DbRecord>, StorageError> {
+        if ids.is_empty() {
+            return Ok(vec![]);
+        }
+
+        let temp_table = TempTable::for_ids::<St>();
+        if !temp_table.can_create() {
+            // AZKS does not support batch get, so we just return an empty vec
+            return Ok(Vec::new());
+        }
+        let create_temp_table = temp_table.create();
+        let temp_table_name = &temp_table.to_string();
+
+        let mut conn = self.get_connection().await?;
+
+        // Begin a transaction
+        conn.simple_query("BEGIN TRANSACTION")
+            .await
+            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
+
+        let result = async {
+            // Use bulk_insert to insert all the ids into a temporary table
+            conn.simple_query(&create_temp_table)
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;
+            let mut bulk = conn
+                .bulk_insert(&temp_table_name)
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
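+            // Each id becomes one temp-table row; the tuple layout must match
+            // the temp table's column order (label_len/label_val for tree
+            // nodes, raw_label/epoch for value states).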
+            for row in DbRecord::get_batch_temp_table_rows::<St>(ids)? {
+                bulk.send(row).await.map_err(|e| {
+                    StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
+                })?;
+            }
+            bulk.finalize()
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;
+
+            // Read rows matching the ids from the temporary table
+            let get_sql = DbRecord::get_batch_statement::<St>();
+            let query_stream = conn.simple_query(&get_sql).await.map_err(|e| {
+                StorageError::Other(format!("Failed to execute batch get query: {e}"))
+            })?;
+            let mut records = Vec::new();
+            {
+                let rows = query_stream
+                    .into_first_result()
+                    .await
+                    .map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
+                for row in rows {
+                    let record = DbRecord::from_row::<St>(&row)?;
+                    records.push(record);
+                }
+            }
+
+            let drop_temp_table = temp_table.drop();
+            conn.simple_query(&drop_temp_table)
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;
+
+            Ok::<Vec<DbRecord>, StorageError>(records)
+        };
+
+        // Commit or roll back the transaction
+        match result.await {
+            Ok(records) => {
+                conn.simple_query("COMMIT").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
+                })?;
+                Ok(records)
+            }
+            Err(e) => {
+                conn.simple_query("ROLLBACK").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
+                })?;
+                Err(e)
+            }
+        }
     }

-    async fn get_user_data(&self, username: &AkdLabel) -> Result<KeyData, StorageError> {
-        todo!()
+    // Note: "user"/"username" here is really the raw_label. The assumption that
+    // this is a single key for a single user is too restrictive for what we
+    // want, so we generalize the name a bit.
+    async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<KeyData, StorageError> {
+        // Note: don't log raw_label or data as it may contain sensitive information, such as PII.
+
+        let result = async {
+            let mut conn = self.get_connection().await?;
+
+            let statement = values::get_all(raw_label);
+
+            let query_stream = conn
+                .query(statement.sql(), &statement.params())
+                .await
+                .map_err(|e| {
+                    StorageError::Other(format!("Failed to execute query for label: {e}"))
+                })?;
+            let mut states = Vec::new();
+            {
+                let rows = query_stream.into_first_result().await.map_err(|e| {
+                    StorageError::Other(format!("Failed to fetch rows for label: {e}"))
+                })?;
+                for row in rows {
+                    let record = statement.parse(&row)?;
+                    states.push(record);
+                }
+            }
+            let key_data = KeyData { states };
+
+            Ok::<KeyData, StorageError>(key_data)
+        };
+
+        match result.await {
+            Ok(data) => Ok(data),
+            Err(e) => Err(StorageError::Other(format!(
+                "Failed to get all data for label: {}",
+                e
+            ))),
+        }
     }

+    // Note: "user"/"username" here is really the raw_label. The assumption that
+    // this is a single key for a single user is too restrictive for what we
+    // want, so we generalize the name a bit.
     async fn get_user_state(
         &self,
-        username: &AkdLabel,
+        raw_label: &AkdLabel,
         flag: types::ValueStateRetrievalFlag,
     ) -> Result<types::ValueState, StorageError> {
-        todo!()
+        let statement = values::get_by_flag(raw_label, flag);
+        let mut conn = self.get_connection().await?;
+        let query_stream = conn
+            .query(statement.sql(), &statement.params())
+            .await
+            .map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;
+        let row = query_stream
+            .into_row()
+            .await
+            .map_err(|e| StorageError::Other(format!("Failed to fetch first result row: {e}")))?;
+        if let Some(row) = row {
+            statement.parse(&row)
+        } else {
+            Err(StorageError::NotFound(format!(
+                "ValueState for label {:?} not found",
+                raw_label
+            )))
+        }
     }

+    // Note: "user"/"username" here is really the raw_label. The assumption that
+    // this is a single key for a single user is too restrictive for what we
+    // want, so we generalize the name a bit.
     async fn get_user_state_versions(
         &self,
-        usernames: &[AkdLabel],
+        raw_labels: &[AkdLabel],
         flag: types::ValueStateRetrievalFlag,
     ) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
-        todo!()
+        let mut conn = self.get_connection().await?;
+
+        let temp_table = TempTable::RawLabelSearch;
+        let create_temp_table = temp_table.create();
+        let temp_table_name = &temp_table.to_string();
+
+        // Begin a transaction
+        conn.simple_query("BEGIN TRANSACTION")
+            .await
+            .map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
+
+        let result = async {
+            conn.simple_query(&create_temp_table)
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;
+
+            // Use bulk_insert to insert all the raw_labels into a temporary table
+            let mut bulk = conn
+                .bulk_insert(&temp_table_name)
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
+            for raw_label in raw_labels {
+                let row = (raw_label.0.clone()).into_row();
+                bulk.send(row).await.map_err(|e| {
+                    StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
+                })?;
+            }
+            bulk.finalize()
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;
+
+            // Read rows matching the raw_labels from the temporary table
+            let statement = values::get_versions_by_flag(&temp_table_name, flag);
+            let query_stream = conn
+                .query(statement.sql(), &statement.params())
+                .await
+                .map_err(|e| {
+                    StorageError::Other(format!("Failed to execute batch get query: {e}"))
+                })?;
+
+            let mut results = HashMap::new();
+            let rows = query_stream
+                .into_first_result()
+                .await
+                .map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
+            for row in rows {
+                let label_version = statement.parse(&row)?;
+                results.insert(
+                    label_version.label,
+                    (label_version.version, label_version.data),
+                );
+            }
+            Ok(results)
+        };
+
+        match result.await {
+            Ok(results) => {
+                conn.simple_query("COMMIT").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to commit transaction: {e}"))
+                })?;
+                Ok(results)
+            }
+            Err(e) => {
+                conn.simple_query("ROLLBACK").await.map_err(|e| {
+                    StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
+                })?;
+                Err(e)
+            }
+        }
     }
 }

diff --git a/akd/crates/akd_storage/src/ms_sql_storable.rs b/akd/crates/akd_storage/src/ms_sql_storable.rs
new file mode 100644
index 0000000000..b26b66fc35
--- /dev/null
+++ b/akd/crates/akd_storage/src/ms_sql_storable.rs
@@ -0,0 +1,781 @@
+use akd::{
+    errors::StorageError,
+    storage::{
+        types::{DbRecord, StorageType, ValueState},
+        Storable,
+    },
+    tree_node::TreeNodeWithPreviousValue,
+    NodeLabel,
+};
+use ms_database::{ColumnData, IntoRow, Row, ToSql, TokenRow};
+
+use crate::{
+    migrations::{TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_VALUES},
+    temp_table::TempTable,
+};
+use crate::sql_params::SqlParams;
+
+const SELECT_AZKS_DATA: &'static [&str] = &["epoch", "num_nodes"];
+const SELECT_HISTORY_TREE_NODE_DATA: &'static [&str] = &[
+    "label_len",
+    "label_val",
+    "last_epoch",
+    "least_descendant_ep",
+    "parent_label_len",
+    "parent_label_val",
+    "node_type",
+    "left_child_len",
+    "left_child_label_val",
+    "right_child_len",
+    "right_child_label_val",
+    "hash",
+    "p_last_epoch",
+    "p_least_descendant_ep",
+    "p_parent_label_len",
+    "p_parent_label_val",
+    "p_node_type",
+    "p_left_child_len",
+    "p_left_child_label_val",
+    "p_right_child_len",
+    "p_right_child_label_val",
+    "p_hash",
+];
+const SELECT_LABEL_DATA: &'static [&str] = &[
+    "raw_label",
+    "epoch",
+    "version",
+    "node_label_val",
+    "node_label_len",
+    "data",
+];
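+
+/// Pairs a SQL string with its positional `@P1..@Pn` parameters, so callers
+/// can run `conn.execute(statement.sql(), &statement.params())` without
+/// tracking placeholder order by hand.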
+ "p_right_child_label_val", + "p_hash", +]; +const SELECT_LABEL_DATA: &'static [&str] = &[ + "raw_label", + "epoch", + "version", + "node_label_val", + "node_label_len", + "data", +]; + +pub(crate) struct Statement { + sql: String, + params: SqlParams, +} + +impl Statement { + pub fn new(sql: String, params: SqlParams) -> Self { + Self { sql, params } + } + + pub fn sql(&self) -> &str { + &self.sql + } + + pub fn params(&self) -> Vec<&dyn ToSql> { + self.params.values() + } +} + +pub(crate) struct QueryStatement { + sql: String, + params: SqlParams, + parser: fn(&Row) -> Result, +} + +impl QueryStatement { + pub fn new(sql: String, params: SqlParams, parser: fn(&Row) -> Result) -> Self { + Self { sql, params, parser } + } + + pub fn sql(&self) -> &str { + &self.sql + } + + pub fn params(&self) -> Vec<&dyn ToSql> { + self.params.values() + } + + pub fn parse(&self, row: &Row) -> Result { + (self.parser)(row) + } +} + +pub(crate) trait MsSqlStorable { + fn set_statement(&self) -> Result; + + fn set_batch_statement(storage_type: &StorageType) -> String; + + fn get_statement(key: &St::StorageKey) -> Result; + + fn get_batch_temp_table_rows( + key: &[St::StorageKey], + ) -> Result, StorageError>; + + fn get_batch_statement() -> String; + + fn from_row(row: &Row) -> Result + where + Self: Sized; + + fn into_row(&self) -> Result; +} + +impl MsSqlStorable for DbRecord { + fn set_statement(&self) -> Result { + match &self { + DbRecord::Azks(azks) => { + let mut params = SqlParams::new(); + params.add("key", Box::new(1u8)); // constant key + // TODO: Fixup as conversions + params.add("epoch", Box::new(azks.latest_epoch as i64)); + params.add("num_nodes", Box::new(azks.num_nodes as i64)); + + let sql = format!( + r#" + MERGE INTO dbo.{TABLE_AZKS} AS target + USING (SELECT {}) AS source + ON target.[key] = source.[key] + WHEN MATCHED THEN + UPDATE SET {} + target.[epoch] = source.[epoch], + target.[num_nodes] = source.[num_nodes] + WHEN NOT MATCHED THEN + INSERT ({}) + VALUES ({}); + "#, + params.keys_as_columns().join(", "), + params.set_columns_equal("target.", "source.").join(", "), + params.columns().join(", "), + params.keys().join(", ") + ); + + Ok(Statement::new(sql, params)) + } + DbRecord::TreeNode(node) => { + let mut params = SqlParams::new(); + params.add("label_len", Box::new(node.label.label_len as i32)); + params.add("label_val", Box::new(node.label.label_val.to_vec())); + // Latest node values + params.add("last_epoch", Box::new(node.latest_node.last_epoch as i64)); + params.add( + "least_descendant_ep", + Box::new(node.latest_node.min_descendant_epoch as i64), + ); + params.add( + "parent_label_len", + Box::new(node.latest_node.parent.label_len as i32), + ); + params.add( + "parent_label_val", + Box::new(node.latest_node.parent.label_val.to_vec()), + ); + params.add("node_type", Box::new(node.latest_node.node_type as i16)); + params.add( + "left_child_len", + Box::new(node.latest_node.left_child.map(|lc| lc.label_len as i32)), + ); + params.add( + "left_child_val", + Box::new(node.latest_node.left_child.map(|lc| lc.label_val.to_vec())), + ); + params.add( + "right_child_len", + Box::new(node.latest_node.right_child.map(|rc| rc.label_len as i32)), + ); + params.add( + "right_child_val", + Box::new(node.latest_node.right_child.map(|rc| rc.label_val.to_vec())), + ); + params.add("[hash]", Box::new(node.latest_node.hash.0.to_vec())); + // Previous node values + params.add( + "p_last_epoch", + Box::new(node.previous_node.clone().map(|p| p.last_epoch as i64)), + ); + params.add( + 
"p_least_descendant_ep", + Box::new( + node.previous_node + .clone() + .map(|p| p.min_descendant_epoch as i64), + ), + ); + params.add( + "p_parent_label_len", + Box::new(node.previous_node.clone().map(|p| p.label.label_len as i32)), + ); + params.add( + "p_parent_label_val", + Box::new( + node.previous_node + .clone() + .map(|p| p.label.label_val.to_vec()), + ), + ); + params.add( + "p_node_type", + Box::new(node.previous_node.clone().map(|p| p.node_type as i16)), + ); + params.add( + "p_left_child_len", + Box::new( + node.previous_node + .clone() + .and_then(|p| p.left_child.map(|lc| lc.label_len as i32)), + ), + ); + params.add( + "p_left_child_val", + Box::new( + node.previous_node + .clone() + .and_then(|p| p.left_child.map(|lc| lc.label_val.to_vec())), + ), + ); + params.add( + "p_right_child_len", + Box::new( + node.previous_node + .clone() + .and_then(|p| p.right_child.map(|rc| rc.label_len as i32)), + ), + ); + params.add( + "p_right_child_val", + Box::new( + node.previous_node + .clone() + .and_then(|p| p.right_child.map(|rc| rc.label_val.to_vec())), + ), + ); + params.add( + "p_hash", + Box::new(node.previous_node.clone().map(|p| p.hash.0.to_vec())), + ); + + let sql = format!( + r#" + MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target + USING (SELECT {}) AS source + ON target.label_len = source.label_len AND target.label_val = source.label_val + WHEN MATCHED THEN + UPDATE SET {} + WHEN NOT MATCHED THEN + INSERT ({}) + VALUES ({}); + "#, + params.keys_as_columns().join(", "), + params + .set_columns_equal_except( + "target.", + "source.", + vec!["label_len", "label_val"] + ) + .join(", "), + params.columns().join(", "), + params.keys().join(", "), + ); + + Ok(Statement::new(sql, params)) + } + DbRecord::ValueState(state) => { + let mut params = SqlParams::new(); + params.add("raw_label", Box::new(state.get_id().0.clone())); + // TODO: Fixup as conversions + params.add("epoch", Box::new(state.epoch as i64)); + params.add("[version]", Box::new(state.version as i64)); + params.add("node_label_val", Box::new(state.label.label_val.to_vec())); + params.add("node_label_len", Box::new(state.label.label_len as i64)); + params.add("[data]", Box::new(state.value.0.clone())); + + // Note: raw_label & epoch are combined the primary key, so these are always new + let sql = format!( + r#" + INSERT INTO dbo.{TABLE_VALUES} ({}) + VALUES ({}); + "#, + params.columns().join(", "), + params.keys().join(", "), + ); + Ok(Statement::new(sql, params)) + } + } + } + + fn set_batch_statement(storage_type: &StorageType) -> String { + match storage_type { + StorageType::Azks => format!( + r#" + MERGE INTO dbo.{TABLE_AZKS} AS target + USING {} AS source + ON target.[key] = source.[key] + WHEN MATCHED THEN + UPDATE SET + target.[epoch] = source.[epoch], + target.[num_nodes] = source.[num_nodes] + WHEN NOT MATCHED THEN + INSERT ([key], [epoch], [num_nodes]) + VALUES (source.[key], source.[epoch], source.[num_nodes]); + "#, + TempTable::Azks.to_string() + ), + StorageType::TreeNode => format!( + r#" + MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target + USING {} AS source + ON target.label_len = source.label_len AND target.label_val = source.label_val + WHEN MATCHED THEN + UPDATE SET + target.last_epoch = source.last_epoch, + target.least_descendant_ep = source.least_descendant_ep, + target.parent_label_len = source.parent_label_len, + target.parent_label_val = source.parent_label_val, + target.node_type = source.node_type, + target.left_child_len = source.left_child_len, + target.left_child_label_val = 
+    fn set_batch_statement(storage_type: &StorageType) -> String {
+        match storage_type {
+            StorageType::Azks => format!(
+                r#"
+                MERGE INTO dbo.{TABLE_AZKS} AS target
+                USING {} AS source
+                ON target.[key] = source.[key]
+                WHEN MATCHED THEN
+                    UPDATE SET
+                        target.[epoch] = source.[epoch],
+                        target.[num_nodes] = source.[num_nodes]
+                WHEN NOT MATCHED THEN
+                    INSERT ([key], [epoch], [num_nodes])
+                    VALUES (source.[key], source.[epoch], source.[num_nodes]);
+                "#,
+                TempTable::Azks.to_string()
+            ),
+            StorageType::TreeNode => format!(
+                r#"
+                MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target
+                USING {} AS source
+                ON target.label_len = source.label_len AND target.label_val = source.label_val
+                WHEN MATCHED THEN
+                    UPDATE SET
+                        target.last_epoch = source.last_epoch,
+                        target.least_descendant_ep = source.least_descendant_ep,
+                        target.parent_label_len = source.parent_label_len,
+                        target.parent_label_val = source.parent_label_val,
+                        target.node_type = source.node_type,
+                        target.left_child_len = source.left_child_len,
+                        target.left_child_label_val = source.left_child_label_val,
+                        target.right_child_len = source.right_child_len,
+                        target.right_child_label_val = source.right_child_label_val,
+                        target.hash = source.hash,
+                        target.p_last_epoch = source.p_last_epoch,
+                        target.p_least_descendant_ep = source.p_least_descendant_ep,
+                        target.p_parent_label_len = source.p_parent_label_len,
+                        target.p_parent_label_val = source.p_parent_label_val,
+                        target.p_node_type = source.p_node_type,
+                        target.p_left_child_len = source.p_left_child_len,
+                        target.p_left_child_label_val = source.p_left_child_label_val,
+                        target.p_right_child_len = source.p_right_child_len,
+                        target.p_right_child_label_val = source.p_right_child_label_val,
+                        target.p_hash = source.p_hash
+                WHEN NOT MATCHED THEN
+                    INSERT (
+                        label_len
+                        , label_val
+                        , last_epoch
+                        , least_descendant_ep
+                        , parent_label_len
+                        , parent_label_val
+                        , node_type
+                        , left_child_len
+                        , left_child_label_val
+                        , right_child_len
+                        , right_child_label_val
+                        , hash
+                        , p_last_epoch
+                        , p_least_descendant_ep
+                        , p_parent_label_len
+                        , p_parent_label_val
+                        , p_node_type
+                        , p_left_child_len
+                        , p_left_child_label_val
+                        , p_right_child_len
+                        , p_right_child_label_val
+                        , p_hash
+                    )
+                    VALUES (
+                        source.label_len
+                        , source.label_val
+                        , source.last_epoch
+                        , source.least_descendant_ep
+                        , source.parent_label_len
+                        , source.parent_label_val
+                        , source.node_type
+                        , source.left_child_len
+                        , source.left_child_label_val
+                        , source.right_child_len
+                        , source.right_child_label_val
+                        , source.hash
+                        , source.p_last_epoch
+                        , source.p_least_descendant_ep
+                        , source.p_parent_label_len
+                        , source.p_parent_label_val
+                        , source.p_node_type
+                        , source.p_left_child_len
+                        , source.p_left_child_label_val
+                        , source.p_right_child_len
+                        , source.p_right_child_label_val
+                        , source.p_hash
+                    );
+                "#,
+                TempTable::HistoryTreeNodes.to_string()
+            ),
+            StorageType::ValueState => format!(
+                r#"
+                MERGE INTO dbo.{TABLE_VALUES} AS target
+                USING {} AS source
+                ON target.raw_label = source.raw_label AND target.epoch = source.epoch
+                WHEN MATCHED THEN
+                    UPDATE SET
+                        target.[version] = source.[version],
+                        target.node_label_val = source.node_label_val,
+                        target.node_label_len = source.node_label_len,
+                        target.[data] = source.[data]
+                WHEN NOT MATCHED THEN
+                    INSERT (raw_label, epoch, [version], node_label_val, node_label_len, [data])
+                    VALUES (source.raw_label, source.epoch, source.[version], source.node_label_val, source.node_label_len, source.[data]);
+                "#,
+                TempTable::Values.to_string()
+            ),
+        }
+    }
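+    // Point lookups decode the typed key back out of the Storable's full
+    // binary id, then filter on the table's primary key columns.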
+    fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError> {
+        let mut params = SqlParams::new();
+        let sql = match St::data_type() {
+            StorageType::Azks => {
+                format!(
+                    r#"
+                    SELECT TOP(1) {} FROM dbo.{};
+                    "#,
+                    SELECT_AZKS_DATA.join(", "),
+                    TABLE_AZKS
+                )
+            }
+            StorageType::TreeNode => {
+                let bin = St::get_full_binary_key_id(key);
+                // These are constructed from a safe key, they should never fail
+                let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
+                    .expect("Failed to decode key"); // TODO: should this be an error?
+
+                params.add("label_len", Box::new(key.0.label_len as i32));
+                params.add("label_val", Box::new(key.0.label_val.to_vec()));
+
+                format!(
+                    r#"
+                    SELECT {} FROM dbo.{TABLE_HISTORY_TREE_NODES} WHERE [label_len] = {} AND [label_val] = {};
+                    "#,
+                    SELECT_HISTORY_TREE_NODE_DATA.join(", "),
+                    params.key_for("label_len").expect("key present"),
+                    params.key_for("label_val").expect("key present"),
+                )
+            }
+            StorageType::ValueState => {
+                let bin = St::get_full_binary_key_id(key);
+                // These are constructed from a safe key, they should never fail
+                let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?
+
+                params.add("raw_label", Box::new(key.0.clone()));
+                params.add("epoch", Box::new(key.1 as i64));
+                format!(
+                    r#"
+                    SELECT {} FROM dbo.{TABLE_VALUES} WHERE [raw_label] = {} AND [epoch] = {};
+                    "#,
+                    SELECT_LABEL_DATA.join(", "),
+                    params.key_for("raw_label").expect("key present"),
+                    params.key_for("epoch").expect("key present"),
+                )
+            }
+        };
+
+        Ok(Statement::new(sql, params))
+    }
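+    // Rows produced here feed the bulk insert in batch_get; the tuple shape
+    // must stay in sync with the matching TempTable::Ids CREATE TABLE layout.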
+    fn get_batch_temp_table_rows<St: Storable>(
+        key: &[St::StorageKey],
+    ) -> Result<Vec<TokenRow<'static>>, StorageError> {
+        match St::data_type() {
+            StorageType::Azks => Err(StorageError::Other(
+                "Batch temp table rows not supported for Azks".to_string(),
+            )),
+            StorageType::TreeNode => {
+                let mut rows = Vec::new();
+                for k in key {
+                    let bin = St::get_full_binary_key_id(k);
+                    // These are constructed from a safe key, they should never fail
+                    let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
+                        .expect("Failed to decode key");
+
+                    let row = (key.0.label_len as i32, key.0.label_val.to_vec()).into_row();
+                    rows.push(row);
+                }
+                Ok(rows)
+            }
+            StorageType::ValueState => {
+                let mut rows = Vec::new();
+                for k in key {
+                    let bin = St::get_full_binary_key_id(k);
+                    // These are constructed from a safe key, they should never fail
+                    let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?
+
+                    let row = (key.0.clone(), key.1 as i64).into_row();
+                    rows.push(row);
+                }
+                Ok(rows)
+            }
+        }
+    }
+
+    fn get_batch_statement<St: Storable>() -> String {
+        // Note: any changes to these columns need to be reflected in from_row below
+        match St::data_type() {
+            StorageType::Azks => panic!("Batch get not supported for Azks"),
+            StorageType::TreeNode => format!(
+                r#"
+                SELECT {}
+                FROM dbo.{TABLE_HISTORY_TREE_NODES} h
+                INNER JOIN {} t
+                    ON h.label_len = t.label_len
+                    AND h.label_val = t.label_val;
+                "#,
+                SELECT_HISTORY_TREE_NODE_DATA
+                    .iter()
+                    .map(|s| format!("h.{s}"))
+                    .collect::<Vec<_>>()
+                    .join(", "),
+                TempTable::for_ids::<St>().to_string()
+            ),
+            StorageType::ValueState => format!(
+                r#"
+                SELECT {}
+                FROM dbo.{TABLE_VALUES} v
+                INNER JOIN {} t
+                    ON v.raw_label = t.raw_label
+                    AND v.epoch = t.epoch;
+                "#,
+                SELECT_LABEL_DATA
+                    .iter()
+                    .map(|s| format!("v.{s}"))
+                    .collect::<Vec<_>>()
+                    .join(", "),
+                TempTable::for_ids::<St>().to_string()
+            ),
+        }
+    }
+
+    fn from_row<St: Storable>(row: &Row) -> Result<Self, StorageError>
+    where
+        Self: Sized,
+    {
+        match St::data_type() {
+            // TODO: check this
+            StorageType::Azks => {
+                let epoch: i64 = row
+                    .get("epoch")
+                    .ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?;
+                let num_nodes: i64 = row.get("num_nodes").ok_or_else(|| {
+                    StorageError::Other("num_nodes is NULL or missing".to_string())
+                })?;
+
+                let azks = DbRecord::build_azks(epoch as u64, num_nodes as u64);
+                Ok(DbRecord::Azks(azks))
+            }
+            StorageType::TreeNode => {
+                let label_len: i32 = row.get("label_len").ok_or_else(|| {
+                    StorageError::Other("label_len is NULL or missing".to_string())
+                })?;
+                let label_val: &[u8] = row.get("label_val").ok_or_else(|| {
+                    StorageError::Other("label_val is NULL or missing".to_string())
+                })?;
+                let last_epoch: i64 = row.get("last_epoch").ok_or_else(|| {
+                    StorageError::Other("last_epoch is NULL or missing".to_string())
+                })?;
+                let least_descendant_ep: i64 = row.get("least_descendant_ep").ok_or_else(|| {
+                    StorageError::Other("least_descendant_ep is NULL or missing".to_string())
+                })?;
+                let parent_label_len: i32 = row.get("parent_label_len").ok_or_else(|| {
+                    StorageError::Other("parent_label_len is NULL or missing".to_string())
+                })?;
+                let parent_label_val: &[u8] = row.get("parent_label_val").ok_or_else(|| {
+                    StorageError::Other("parent_label_val is NULL or missing".to_string())
+                })?;
+                let node_type: i16 = row.get("node_type").ok_or_else(|| {
+                    StorageError::Other("node_type is NULL or missing".to_string())
+                })?;
+                let left_child_len: Option<i32> = row.get("left_child_len");
+                let left_child_label_val: Option<&[u8]> = row.get("left_child_label_val");
+                let right_child_len: Option<i32> = row.get("right_child_len");
+                let right_child_label_val: Option<&[u8]> = row.get("right_child_label_val");
+                let hash: &[u8] = row
+                    .get("hash")
+                    .ok_or_else(|| StorageError::Other("hash is NULL or missing".to_string()))?;
+                let p_last_epoch: Option<i64> = row.get("p_last_epoch");
+                let p_least_descendant_ep: Option<i64> = row.get("p_least_descendant_ep");
+                let p_parent_label_len: Option<i32> = row.get("p_parent_label_len");
+                let p_parent_label_val: Option<&[u8]> = row.get("p_parent_label_val");
+                let p_node_type: Option<i16> = row.get("p_node_type");
+                let p_left_child_len: Option<i32> = row.get("p_left_child_len");
+                let p_left_child_label_val: Option<&[u8]> = row.get("p_left_child_label_val");
+                let p_right_child_len: Option<i32> = row.get("p_right_child_len");
+                let p_right_child_label_val: Option<&[u8]> = row.get("p_right_child_label_val");
+                let p_hash: Option<&[u8]> = row.get("p_hash");
+                // Make child nodes: a child exists only when both its length
+                // and value columns are non-NULL
+                fn optional_child_label(
+                    child_val: Option<&[u8]>,
+                    child_len: Option<i32>,
+                ) -> Result<Option<NodeLabel>, StorageError> {
+                    match (child_val, child_len) {
+                        (Some(val), Some(len)) => {
+                            let val_arr: [u8; 32] = val.to_vec().try_into().map_err(|_| {
+                                StorageError::Other("child_val has incorrect length".to_string())
+                            })?;
+                            Ok(Some(NodeLabel::new(val_arr, len as u32)))
+                        }
+                        _ => Ok(None),
+                    }
+                }
+                let left_child = optional_child_label(left_child_label_val, left_child_len)?;
+                let right_child = optional_child_label(right_child_label_val, right_child_len)?;
+                let p_left_child = optional_child_label(p_left_child_label_val, p_left_child_len)?;
+                let p_right_child =
+                    optional_child_label(p_right_child_label_val, p_right_child_len)?;
+
+                let massaged_p_parent_label_val: Option<[u8; 32]> = match p_parent_label_val {
+                    Some(v) => Some(v.to_vec().try_into().map_err(|_| {
+                        StorageError::Other("p_parent_label_val has incorrect length".to_string())
+                    })?),
+                    None => None,
+                };
+                let massaged_hash: akd::Digest = akd::hash::try_parse_digest(&hash.to_vec())
+                    .map_err(|_| StorageError::Other("hash has incorrect length".to_string()))?;
+                let massaged_p_hash: Option<akd::Digest> = match p_hash {
+                    Some(v) => Some(akd::hash::try_parse_digest(&v.to_vec()).map_err(|_| {
+                        StorageError::Other("p_hash has incorrect length".to_string())
+                    })?),
+                    None => None,
+                };
+
+                let node = DbRecord::build_tree_node_with_previous_value(
+                    label_val.try_into().map_err(|_| {
+                        StorageError::Other("label_val has incorrect length".to_string())
+                    })?,
+                    label_len as u32,
+                    last_epoch as u64,
+                    least_descendant_ep as u64,
+                    parent_label_val.try_into().map_err(|_| {
+                        StorageError::Other("parent_label_val has incorrect length".to_string())
+                    })?,
+                    parent_label_len as u32,
+                    node_type as u8,
+                    left_child,
+                    right_child,
+                    massaged_hash,
+                    p_last_epoch.map(|v| v as u64),
+                    p_least_descendant_ep.map(|v| v as u64),
+                    massaged_p_parent_label_val,
+                    p_parent_label_len.map(|v| v as u32),
+                    p_node_type.map(|v| v as u8),
+                    p_left_child,
+                    p_right_child,
+                    massaged_p_hash,
+                );
+
+                Ok(DbRecord::TreeNode(node))
+            }
+            StorageType::ValueState => {
+                Ok(DbRecord::ValueState(crate::tables::values::from_row(row)?))
+            }
+        }
+    }
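+    // Column order below must match the temp tables' CREATE TABLE column
+    // order: TDS bulk insert binds by position, not by name.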
+    fn into_row(&self) -> Result<TokenRow<'static>, StorageError> {
+        match &self {
+            DbRecord::Azks(azks) => {
+                let row = (
+                    1u8, // constant key
+                    azks.latest_epoch as i64,
+                    azks.num_nodes as i64,
+                )
+                    .into_row();
+                Ok(row)
+            }
+            DbRecord::TreeNode(node) => {
+                let mut row = TokenRow::new();
+                row.push(ColumnData::I32(Some(node.label.label_len as i32)));
+                row.push(ColumnData::Binary(Some(
+                    node.label.label_val.to_vec().into(),
+                )));
+                // Latest node values
+                row.push(ColumnData::I64(Some(node.latest_node.last_epoch as i64)));
+                row.push(ColumnData::I64(Some(
+                    node.latest_node.min_descendant_epoch as i64,
+                )));
+                row.push(ColumnData::I32(Some(
+                    node.latest_node.parent.label_len as i32,
+                )));
+                row.push(ColumnData::Binary(Some(
+                    node.latest_node.parent.label_val.to_vec().into(),
+                )));
+                row.push(ColumnData::I16(Some(node.latest_node.node_type as i16)));
+                match &node.latest_node.left_child {
+                    Some(lc) => {
+                        row.push(ColumnData::I32(Some(lc.label_len as i32)));
+                        row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
+                    }
+                    None => {
+                        row.push(ColumnData::I32(None));
+                        row.push(ColumnData::Binary(None));
+                    }
+                }
+                match &node.latest_node.right_child {
+                    Some(rc) => {
+                        row.push(ColumnData::I32(Some(rc.label_len as i32)));
+                        row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
+                    }
+                    None => {
+                        row.push(ColumnData::I32(None));
+                        row.push(ColumnData::Binary(None));
+                    }
+                }
+                row.push(ColumnData::Binary(Some(
+                    node.latest_node.hash.0.to_vec().into(),
+                )));
+                // Previous node values
+                match &node.previous_node {
+                    Some(p) => {
+                        row.push(ColumnData::I64(Some(p.last_epoch as i64)));
+                        row.push(ColumnData::I64(Some(p.min_descendant_epoch as i64)));
+                        row.push(ColumnData::I32(Some(p.label.label_len as i32)));
+                        row.push(ColumnData::Binary(Some(p.label.label_val.to_vec().into())));
+                        row.push(ColumnData::I16(Some(p.node_type as i16)));
+                        match &p.left_child {
+                            Some(lc) => {
+                                row.push(ColumnData::I32(Some(lc.label_len as i32)));
+                                row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
+                            }
+                            None => {
+                                row.push(ColumnData::I32(None));
+                                row.push(ColumnData::Binary(None));
+                            }
+                        }
+                        match &p.right_child {
+                            Some(rc) => {
+                                row.push(ColumnData::I32(Some(rc.label_len as i32)));
+                                row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
+                            }
+                            None => {
+                                row.push(ColumnData::I32(None));
+                                row.push(ColumnData::Binary(None));
+                            }
+                        }
+                        row.push(ColumnData::Binary(Some(p.hash.0.to_vec().into())));
+                    }
+                    None => {
+                        // Node values
+                        row.push(ColumnData::I64(None));
+                        row.push(ColumnData::I64(None));
+                        row.push(ColumnData::I32(None));
+                        row.push(ColumnData::Binary(None));
+                        row.push(ColumnData::I16(None));
+                        // Left child
+                        row.push(ColumnData::I32(None));
+                        row.push(ColumnData::Binary(None));
+                        // Right child
+                        row.push(ColumnData::I32(None));
+                        row.push(ColumnData::Binary(None));
+                        // Hash
+                        row.push(ColumnData::Binary(None));
+                    }
+                }
+                Ok(row)
+            }
+            DbRecord::ValueState(state) => {
+                let row = (
+                    state.get_id().0.clone(),
+                    state.epoch as i64,
+                    state.version as i64,
+                    state.label.label_val.to_vec(),
+                    state.label.label_len as i64,
+                    state.value.0.clone(),
+                )
+                    .into_row();
+                Ok(row)
+            }
+        }
+    }
+}

diff --git a/akd/crates/akd_storage/src/sql_params.rs b/akd/crates/akd_storage/src/sql_params.rs
new file mode 100644
index 0000000000..9e694643d0
--- /dev/null
+++ b/akd/crates/akd_storage/src/sql_params.rs
@@ -0,0 +1,177 @@
+use ms_database::ToSql;
+
+pub struct SqlParam {
+    /// The parameter key (e.g., "@P1", "@P2")
+    pub key: String,
+    /// The column name this parameter maps to
+    column: String,
+    pub data: Box<dyn ToSql>,
+}
+
+impl SqlParam {
+    fn column(&self) -> String {
+        SqlParam::wrap_in_brackets(&self.column)
+    }
+
+    fn wrap_in_brackets(s: &str) -> String {
+        // Ensure column names are wrapped in brackets for SQL Server
+        let trimmed = s.trim();
+        let starts_with_bracket = trimmed.starts_with('[');
+        let ends_with_bracket = trimmed.ends_with(']');
+
+        match (starts_with_bracket, ends_with_bracket) {
+            (true, true) => trimmed.to_string(),
+            (true, false) => format!("{}]", trimmed),
+            (false, true) => format!("[{}", trimmed),
+            (false, false) => format!("[{}]", trimmed),
+        }
+    }
+}
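+
+/// Ordered list of parameters for building parameterized statements.
+///
+/// A small illustrative example:
+///
+/// ```ignore
+/// let mut params = SqlParams::new();
+/// params.add("epoch", Box::new(42i64));
+/// assert_eq!(params.keys(), vec!["@P1".to_string()]);
+/// assert_eq!(params.columns(), vec!["[epoch]".to_string()]);
+/// ```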
+pub(crate) struct SqlParams {
+    params: Vec<Box<SqlParam>>,
+}
+
+impl SqlParams {
+    pub fn new() -> Self {
+        Self { params: Vec::new() }
+    }
+
+    pub fn add(&mut self, column: impl Into<String>, value: Box<dyn ToSql>) {
+        self.params.push(Box::new(SqlParam {
+            key: format!("@P{}", self.params.len() + 1),
+            column: column.into(),
+            data: value,
+        }));
+    }
+
+    pub fn keys(&self) -> Vec<String> {
+        self.params.iter().map(|p| p.key.clone()).collect()
+    }
+
+    pub fn keys_as_columns(&self) -> Vec<String> {
+        self.params
+            .iter()
+            .map(|p| format!("{} AS {}", p.key, p.column()))
+            .collect()
+    }
+
+    pub fn keys_except_columns(&self, excludes: Vec<&str>) -> Vec<String> {
+        self.params
+            .iter()
+            .filter(|p| !excludes.contains(&p.column.as_str()))
+            .map(|p| p.key.clone())
+            .collect()
+    }
+
+    pub fn key_for(&self, column: &str) -> Option<String> {
+        self.params
+            .iter()
+            .find(|p| p.column == column)
+            .map(|p| p.key.clone())
+    }
+
+    pub fn columns(&self) -> Vec<String> {
+        self.params.iter().map(|p| p.column()).collect()
+    }
+
+    pub fn columns_except(&self, excludes: Vec<&str>) -> Vec<String> {
+        self.params
+            .iter()
+            .filter(|p| !excludes.contains(&p.column.as_str()))
+            .map(|p| p.column())
+            .collect()
+    }
+
+    pub fn columns_prefix_with(&self, prefix: &str) -> Vec<String> {
+        self.params
+            .iter()
+            .map(|p| format!("{}{}", prefix, p.column()))
+            .collect()
+    }
+
+    pub fn set_columns_equal(&self, assign_prefix: &str, source_prefix: &str) -> Vec<String> {
+        self.params
+            .iter()
+            .map(|p| {
+                format!(
+                    "{}{} = {}{}",
+                    assign_prefix,
+                    p.column(),
+                    source_prefix,
+                    p.column()
+                )
+            })
+            .collect()
+    }
+
+    pub fn set_columns_equal_except(
+        &self,
+        assign_prefix: &str,
+        source_prefix: &str,
+        excludes: Vec<&str>,
+    ) -> Vec<String> {
+        self.params
+            .iter()
+            .filter(|p| !excludes.contains(&p.column.as_str()))
+            .map(|p| {
+                format!(
+                    "{}{} = {}{}",
+                    assign_prefix,
+                    p.column(),
+                    source_prefix,
+                    p.column()
+                )
+            })
+            .collect()
+    }
+
+    pub fn values(&self) -> Vec<&(dyn ToSql)> {
+        self.params
+            .iter()
+            .map(|b| b.data.as_ref() as &(dyn ToSql))
+            .collect()
+    }
+}
+
+pub struct VecStringBuilder<'a> {
+    strings: Vec<String>,
+    ops: Vec<StringBuilderOperation<'a>>,
+}
+
+enum StringBuilderOperation<'a> {
+    StringOperation(Box<dyn Fn(String) -> String + 'a>),
+    VectorOperation(Box<dyn Fn(Vec<String>) -> Vec<String> + 'a>),
+}
+
+impl<'a> VecStringBuilder<'a> {
+    pub fn new(strings: Vec<String>) -> Self {
+        Self {
+            strings,
+            ops: Vec::new(),
+        }
+    }
+
+    pub fn map<F>(&mut self, op: F)
+    where
+        F: Fn(String) -> String + 'static,
+    {
+        self.ops
+            .push(StringBuilderOperation::StringOperation(Box::new(op)));
+    }
+
+    pub fn build(self) -> Vec<String> {
+        let mut result = self.strings;
+        for op in self.ops {
+            match op {
+                StringBuilderOperation::StringOperation(f) => {
+                    result = result.into_iter().map(f.as_ref()).collect();
+                }
+                StringBuilderOperation::VectorOperation(f) => {
+                    result = f(result);
+                }
+            }
+        }
+        result
+    }
+
+    pub fn join(self, sep: &str) -> String {
+        self.build().join(sep)
+    }
+
+    pub fn except(&mut self, excludes: Vec<&'a str>) {
+        self.ops
+            .push(StringBuilderOperation::VectorOperation(Box::new(
+                move |vec: Vec<String>| {
+                    vec.into_iter()
+                        .filter(|s| !excludes.contains(&s.as_str()))
+                        .collect()
+                },
+            )));
+    }
+}

diff --git a/akd/crates/akd_storage/src/tables/mod.rs b/akd/crates/akd_storage/src/tables/mod.rs
new file mode 100644
index 0000000000..4b9235503b
--- /dev/null
+++ b/akd/crates/akd_storage/src/tables/mod.rs
@@ -0,0 +1 @@
+pub(crate) mod values;

diff --git a/akd/crates/akd_storage/src/tables/values.rs b/akd/crates/akd_storage/src/tables/values.rs
new file mode 100644
index 0000000000..37ae7710c7
--- /dev/null
+++ b/akd/crates/akd_storage/src/tables/values.rs
@@ -0,0 +1,190 @@
+use akd::{
+    errors::StorageError,
+    storage::types::{DbRecord, ValueState, ValueStateRetrievalFlag},
+    AkdLabel, AkdValue,
+};
+use ms_database::Row;
+
+use crate::{
+    migrations::TABLE_VALUES,
+    ms_sql_storable::{QueryStatement, Statement},
+    sql_params::SqlParams,
+};
+
+pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
+    let mut params = SqlParams::new();
+    // The raw vector is the key for value storage
+    params.add("raw_label", Box::new(raw_label.0.clone()));
+
+    let sql = format!(
+        r#"
+        SELECT raw_label, epoch, version, node_label_val, node_label_len, data
+        FROM {}
+        WHERE raw_label = {}
+        "#,
+        TABLE_VALUES,
+        params
+            .key_for("raw_label")
+            .expect("raw_label was added to the params list")
+    );
+    QueryStatement::new(sql, params, from_row)
+}
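+
+/// Builds the single-state lookup used by get_user_state: TOP(1) combined
+/// with the flag-specific filter/ordering picks exactly one row per label.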
.key_for("raw_label") + .expect("raw_label was added to the params list") + ); + QueryStatement::new(sql, params, from_row) +} + +pub fn get_by_flag(raw_label: &AkdLabel, flag: ValueStateRetrievalFlag) -> QueryStatement { + let mut params = SqlParams::new(); + params.add("raw_label", Box::new(raw_label.0.clone())); + + match flag { + ValueStateRetrievalFlag::SpecificEpoch(epoch) + | ValueStateRetrievalFlag::LeqEpoch(epoch) => params.add("epoch", Box::new(epoch as i64)), + ValueStateRetrievalFlag::SpecificVersion(version) => { + params.add("version", Box::new(version as i64)) + } + _ => {} + } + + let mut sql = format!( + r#" + SELECT TOP(1) raw_label, epoch, version, node_label_val, node_label_len, data + FROM {} + WHERE raw_label = {} + "#, + TABLE_VALUES, + params + .key_for("raw_label") + .expect("raw_label was added to the params list") + ); + + match flag { + ValueStateRetrievalFlag::SpecificEpoch(_) => { + sql.push_str(&format!( + " AND epoch = {}", + ¶ms + .key_for("epoch") + .expect("epoch was added to the params list") + )); + } + ValueStateRetrievalFlag::SpecificVersion(_) => { + sql.push_str(&format!( + " AND version = {}", + ¶ms + .key_for("version") + .expect("version was added to the params list") + )); + } + ValueStateRetrievalFlag::MaxEpoch => { + sql.push_str(" ORDER BY epoch DESC "); + } + ValueStateRetrievalFlag::MinEpoch => { + sql.push_str(" ORDER BY epoch ASC "); + } + ValueStateRetrievalFlag::LeqEpoch(_) => { + sql.push_str(&format!( + r#" + AND epoch <= {} + ORDER BY epoch DESC + "#, + ¶ms + .key_for("epoch") + .expect("epoch was added to the params list") + )); + } + } + + QueryStatement::new(sql, params, from_row) +} + +pub fn get_versions_by_flag( + temp_table_name: &str, + flag: ValueStateRetrievalFlag, +) -> QueryStatement { + let mut params = SqlParams::new(); + + let (filter, epoch_col) = match flag { + ValueStateRetrievalFlag::SpecificVersion(version) => { + params.add("version", Box::new(version as i64)); + (format!("WHERE tmp.version = {}", params.key_for("version").expect("version was added to the params list")), "tmp.epoch") + } + ValueStateRetrievalFlag::SpecificEpoch(epoch) => { + params.add("epoch", Box::new(epoch as i64)); + (format!("WHERE tmp.epoch = {}", params.key_for("epoch").expect("epoch was added to the params list")), "tmp.epoch") + } + ValueStateRetrievalFlag::LeqEpoch(epoch) => { + params.add("epoch", Box::new(epoch as i64)); + (format!("WHERE tmp.epoch <= {}", params.key_for("epoch").expect("epoch was added to the params list")), "MAX(tmp.epoch)") + } + ValueStateRetrievalFlag::MaxEpoch => ("".to_string(), "MAX(tmp.epoch)"), + ValueStateRetrievalFlag::MinEpoch => ("".to_string(), "MIN(tmp.epoch)"), + }; + + let sql = format!( + r#" + SELECT t.raw_label, t.version, t.data + FROM {TABLE_VALUES} t + INNER JOIN ( + SELECT tmp.raw_label as raw_label, {} as epoch + FROM {TABLE_VALUES} tmp + INNER JOIN {temp_table_name} s ON s.raw_label = tmp.raw_label + {} + GROUP BY tmp.raw_label + ) epochs on epochs.raw_label = t.raw_label AND epochs.epoch = t.epoch + "#, + epoch_col, + filter, + ); + + QueryStatement::new(sql, params, version_from_row) +} + +pub(crate) fn from_row(row: &Row) -> Result { + let raw_label: &[u8] = row + .get("raw_label") + .ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?; + let epoch: i64 = row + .get("epoch") + .ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?; + let version: i64 = row + .get("version") + .ok_or_else(|| StorageError::Other("version is NULL or 
missing".to_string()))?; + let node_label_val: &[u8] = row + .get("node_label_val") + .ok_or_else(|| StorageError::Other("node_label_val is NULL or missing".to_string()))?; + let node_label_len: i64 = row + .get("node_label_len") + .ok_or_else(|| StorageError::Other("node_label_len is NULL or missing".to_string()))?; + let data: &[u8] = row + .get("data") + .ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?; + + let state = DbRecord::build_user_state( + raw_label.to_vec(), + data.to_vec(), + version as u64, + node_label_len as u32, + node_label_val + .to_vec() + .try_into() + .map_err(|_| StorageError::Other("node_label_val has incorrect length".to_string()))?, + epoch as u64, + ); + Ok(state) +} + +pub(crate) struct LabelVersion { + pub label: AkdLabel, + pub version: u64, + pub data: AkdValue, +} + +fn version_from_row(row: &Row) -> Result { + let raw_label: &[u8] = row + .get("raw_label") + .ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?; + let version: i64 = row + .get("version") + .ok_or_else(|| StorageError::Other("version is NULL or missing".to_string()))?; + let data: &[u8] = row + .get("data") + .ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?; + + Ok(LabelVersion { + label: AkdLabel(raw_label.to_vec()), + version: version as u64, + data: AkdValue(data.to_vec()), + }) +} diff --git a/akd/crates/akd_storage/src/temp_table.rs b/akd/crates/akd_storage/src/temp_table.rs index a06da767b9..a2597ac2f0 100644 --- a/akd/crates/akd_storage/src/temp_table.rs +++ b/akd/crates/akd_storage/src/temp_table.rs @@ -7,13 +7,10 @@ pub(crate) enum TempTable { Azks, HistoryTreeNodes, Values, + RawLabelSearch, } impl TempTable { - pub fn for_ids_for(storage_type: &StorageType) -> Self { - TempTable::Ids(storage_type.clone()) - } - pub fn for_ids() -> Self { TempTable::Ids(St::data_type()) } @@ -111,6 +108,15 @@ impl TempTable { TEMP_IDS_TABLE ), } + TempTable::RawLabelSearch => format!( + r#" + CREATE TABLE {} ( + raw_label VARBINARY(256) NOT NULL, + PRIMARY KEY (raw_label) + ); + "#, + TEMP_SEARCH_LABELS_TABLE + ) } } } @@ -122,6 +128,7 @@ impl ToString for TempTable { TempTable::Azks => TEMP_AZKS_TABLE.to_string(), TempTable::HistoryTreeNodes => TEMP_HISTORY_TREE_NODES_TABLE.to_string(), TempTable::Values => TEMP_VALUES_TABLE.to_string(), + TempTable::RawLabelSearch => TEMP_SEARCH_LABELS_TABLE.to_string(), } } } @@ -137,6 +144,7 @@ impl From for TempTable { } pub(crate) const TEMP_IDS_TABLE: &str = "#akd_temp_ids"; +pub(crate) const TEMP_SEARCH_LABELS_TABLE: &str = "#akd_temp_search_labels"; pub(crate) const TEMP_AZKS_TABLE: &str = "#akd_temp_azks"; pub(crate) const TEMP_HISTORY_TREE_NODES_TABLE: &str = "#akd_temp_history_tree_nodes"; pub(crate) const TEMP_VALUES_TABLE: &str = "#akd_temp_values"; diff --git a/akd/crates/ms_database/src/lib.rs b/akd/crates/ms_database/src/lib.rs index 31c109322a..33823dcea2 100644 --- a/akd/crates/ms_database/src/lib.rs +++ b/akd/crates/ms_database/src/lib.rs @@ -6,7 +6,7 @@ pub use pool::ConnectionManager as MsSqlConnectionManager; pub use pool::{OnConnectError}; // re-expose tiberius types for convenience -pub use tiberius::{error::Error as MsDbError, Column, Row, ToSql}; +pub use tiberius::{error::Error as MsDbError, Column, Row, FromSql, ToSql, IntoRow, TokenRow, ColumnData}; // re-expose bb8 types for convenience pub type Pool = bb8::Pool;