1
0
mirror of https://github.com/bitwarden/server synced 2026-01-30 00:03:48 +00:00

First complete implementation of Database trait for sql server

This commit is contained in:
Matt Gibson
2025-10-21 13:16:05 -07:00
parent 7bb8296a2c
commit bbd1a230a6
8 changed files with 1616 additions and 21 deletions

View File

@@ -2,4 +2,5 @@ mod migrations;
mod sql_params;
mod ms_sql_storable;
pub mod ms_sql;
mod tables;
mod temp_table;

View File

@@ -1,48 +1,485 @@
use std::collections::HashMap;
use std::{cmp::Ordering, collections::HashMap, sync::Arc};
use akd::{errors::StorageError, storage::{types::{self, DbRecord}, Database, DbSetState, Storable}, AkdLabel, AkdValue};
use akd::{
errors::StorageError,
storage::{
types::{self, DbRecord, KeyData, StorageType},
Database, DbSetState, Storable,
},
AkdLabel, AkdValue,
};
use async_trait::async_trait;
use ms_database::ConnectionManager;
use ms_database::{IntoRow, MsSqlConnectionManager, Pool, PooledConnection};
use crate::{
migrations::MIGRATIONS,
ms_sql_storable::{MsSqlStorable, Statement},
tables::values,
temp_table::TempTable,
};
const DEFAULT_POOL_SIZE: u32 = 100;
pub struct MsSqlBuilder {
connection_string: String,
pool_size: Option<u32>,
}
impl MsSqlBuilder {
/// Create a new [MsSqlBuilder] with the given connection string.
pub fn new(connection_string: String) -> Self {
Self {
connection_string,
pool_size: None,
}
}
/// Set the connection pool size. Default is given by [DEFAULT_POOL_SIZE].
pub fn pool_size(mut self, pool_size: u32) -> Self {
self.pool_size = Some(pool_size);
self
}
/// Build the [MsSql] instance.
pub async fn build(self) -> Result<MsSql, StorageError> {
let pool_size = self.pool_size.unwrap_or(DEFAULT_POOL_SIZE);
MsSql::new(self.connection_string, pool_size).await
}
}
pub struct MsSql {
connection_manager: ConnectionManager,
pool: Arc<Pool>,
}
impl MsSql {
pub fn builder(connection_string: String) -> MsSqlBuilder {
MsSqlBuilder::new(connection_string)
}
pub async fn new(connection_string: String, pool_size: u32) -> Result<Self, StorageError> {
let connection_manager = MsSqlConnectionManager::new(connection_string);
let pool = Pool::builder()
.max_size(pool_size)
.build(connection_manager)
.await
.map_err(|e| StorageError::Connection(format!("Failed to create DB pool: {}", e)))?;
Ok(Self {
pool: Arc::new(pool),
})
}
pub async fn migrate(&self) -> Result<(), StorageError> {
let mut conn = self.pool.get().await.map_err(|e| {
StorageError::Connection(format!("Failed to get DB connection for migrations: {}", e))
})?;
ms_database::run_pending_migrations(&mut conn, MIGRATIONS)
.await
.map_err(|e| StorageError::Connection(format!("Failed to run migrations: {}", e)))?;
Ok(())
}
async fn get_connection(&self) -> Result<PooledConnection<'_>, StorageError> {
self.pool
.get()
.await
.map_err(|e| StorageError::Connection(format!("Failed to get DB connection: {}", e)))
}
async fn execute_statement(&self, statement: &Statement) -> Result<(), StorageError> {
let mut conn = self.get_connection().await?;
self.execute_statement_on_connection(statement, &mut conn)
.await
}
async fn execute_statement_on_connection(
&self,
statement: &Statement,
conn: &mut PooledConnection<'_>,
) -> Result<(), StorageError> {
conn.execute(statement.sql(), &statement.params())
.await
.map_err(|e| StorageError::Other(format!("Failed to execute statement: {}", e)))?;
Ok(())
}
}
#[async_trait]
impl Database for MsSql {
async fn set(&self, record: DbRecord) -> Result<(), StorageError> {
todo!()
let statement = record.set_statement()?;
self.execute_statement(&statement).await
}
async fn batch_set(&self, records: Vec<DbRecord>, state: DbSetState) -> Result<(), StorageError> {
todo!()
async fn batch_set(
&self,
records: Vec<DbRecord>,
_state: DbSetState, // TODO: unused in mysql example, but may be needed later
) -> Result<(), StorageError> {
// Generate groups by type
let mut groups = HashMap::new();
for record in records {
match &record {
DbRecord::Azks(_) => groups
.entry(StorageType::Azks)
.or_insert_with(Vec::new)
.push(record),
DbRecord::TreeNode(_) => groups
.entry(StorageType::TreeNode)
.or_insert_with(Vec::new)
.push(record),
DbRecord::ValueState(_) => groups
.entry(StorageType::ValueState)
.or_insert_with(Vec::new)
.push(record),
}
}
// Execute each group in batches
let mut conn = self.get_connection().await?;
// Start transaction
conn.simple_query("BEGIN TRANSACTION")
.await
.map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
let result = async {
for (storage_type, mut record_group) in groups.into_iter() {
if record_group.is_empty() {
continue;
}
// Sort the records to match db-layer sorting
record_group.sort_by(|a, b| match &a {
DbRecord::TreeNode(node) => {
if let DbRecord::TreeNode(node2) = &b {
node.label.cmp(&node2.label)
} else {
Ordering::Equal
}
}
DbRecord::ValueState(state) => {
if let DbRecord::ValueState(state2) = &b {
match state.username.0.cmp(&state2.username.0) {
Ordering::Equal => state.epoch.cmp(&state2.epoch),
other => other,
}
} else {
Ordering::Equal
}
}
_ => Ordering::Equal,
});
// Execute value as bulk insert
// Create temp table
let table: TempTable = storage_type.into();
conn.simple_query(&table.create()).await.map_err(|e| {
StorageError::Other(format!("Failed to create temp table: {e}"))
})?;
// Create bulk insert
let table_name = &table.to_string();
let mut bulk = conn.bulk_insert(table_name).await.map_err(|e| {
StorageError::Other(format!("Failed to start bulk insert: {e}"))
})?;
for record in &record_group {
let row = record.into_row()?;
bulk.send(row).await.map_err(|e| {
StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
})?;
}
bulk.finalize().await.map_err(|e| {
StorageError::Other(format!("Failed to finalize bulk insert: {e}"))
})?;
// Set values from temp table to main table
let sql = <DbRecord as MsSqlStorable>::set_batch_statement(&storage_type);
conn.simple_query(&sql).await.map_err(|e| {
StorageError::Other(format!("Failed to execute batch set statement: {e}"))
})?;
// Delete the temp table
conn.simple_query(&table.drop())
.await
.map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;
}
Ok::<(), StorageError>(())
};
match result.await {
Ok(_) => {
conn.simple_query("COMMIT").await.map_err(|e| {
StorageError::Transaction(format!("Failed to commit transaction: {e}"))
})?;
Ok(())
}
Err(e) => {
conn.simple_query("ROLLBACK").await.map_err(|e| {
StorageError::Transaction(format!("Failed to roll back transaction: {e}"))
})?;
Err(StorageError::Other(format!(
"Failed to batch set records: {}",
e
)))
}
}
}
async fn get<St: Storable>(&self, id: &St::StorageKey) -> Result<DbRecord, StorageError> {
todo!()
let mut conn = self.get_connection().await?;
let statement = DbRecord::get_statement::<St>(id)?;
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;
let row = query_stream
.into_row()
.await
.map_err(|e| StorageError::Other(format!("Failed to fetch row: {e}")))?;
if let Some(row) = row {
DbRecord::from_row::<St>(&row)
} else {
Err(StorageError::NotFound(format!(
"{:?} {:?} not found",
St::data_type(),
id
)))
}
}
async fn batch_get<St: Storable>(&self, ids: &[St::StorageKey]) -> Result<Vec<DbRecord>, StorageError> {
todo!()
async fn batch_get<St: Storable>(
&self,
ids: &[St::StorageKey],
) -> Result<Vec<DbRecord>, StorageError> {
if ids.is_empty() {
return Ok(vec![]);
}
let temp_table = TempTable::for_ids::<St>();
if temp_table.can_create() {
// AZKs does not support batch get, so we just return empty vec
return Ok(Vec::new());
}
let create_temp_table = temp_table.create();
let temp_table_name = &temp_table.to_string();
let mut conn = self.get_connection().await?;
// Begin a transaction
conn.simple_query("BEGIN TRANSACTION")
.await
.map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
let result = async {
// Use bulk_insert to insert all the ids into a temporary table
conn.simple_query(&create_temp_table)
.await
.map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;
let mut bulk = conn
.bulk_insert(&temp_table_name)
.await
.map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
for row in DbRecord::get_batch_temp_table_rows::<St>(ids)? {
bulk.send(row).await.map_err(|e| {
StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
})?;
}
bulk.finalize()
.await
.map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;
// Read rows matching the ids from the temporary table
let get_sql = DbRecord::get_batch_statement::<St>();
let query_stream = conn.simple_query(&get_sql).await.map_err(|e| {
StorageError::Other(format!("Failed to execute batch get query: {e}"))
})?;
let mut records = Vec::new();
{
let rows = query_stream
.into_first_result()
.await
.map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
for row in rows {
let record = DbRecord::from_row::<St>(&row)?;
records.push(record);
}
}
let drop_temp_table = temp_table.drop();
conn.simple_query(&drop_temp_table)
.await
.map_err(|e| StorageError::Other(format!("Failed to drop temp table: {e}")))?;
Ok::<Vec<DbRecord>, StorageError>(vec![])
};
// Commit or rollback the transaction
match result.await {
Ok(records) => {
conn.simple_query("COMMIT").await.map_err(|e| {
StorageError::Transaction(format!("Failed to commit transaction: {e}"))
})?;
Ok(records)
}
Err(e) => {
conn.simple_query("ROLLBACK").await.map_err(|e| {
StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
})?;
Err(e)
}
}
}
async fn get_user_data(&self, username: &AkdLabel) -> Result<types::KeyData, StorageError> {
todo!()
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_data(&self, raw_label: &AkdLabel) -> Result<types::KeyData, StorageError> {
// Note: don't log raw_label or data as it may contain sensitive information, such as PII.
let result = async {
let mut conn = self.get_connection().await?;
let statement = values::get_all(raw_label);
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|e| {
StorageError::Other(format!("Failed to execute query for label: {e}"))
})?;
let mut states = Vec::new();
{
let rows = query_stream.into_first_result().await.map_err(|e| {
StorageError::Other(format!("Failed to fetch rows for label: {e}"))
})?;
for row in rows {
let record = statement.parse(&row)?;
states.push(record);
}
}
let key_data = KeyData { states };
Ok::<KeyData, StorageError>(key_data)
};
match result.await {
Ok(data) => Ok(data),
Err(e) => Err(StorageError::Other(format!(
"Failed to get all data for label: {}",
e
))),
}
}
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_state(
&self,
username: &AkdLabel,
raw_label: &AkdLabel,
flag: types::ValueStateRetrievalFlag,
) -> Result<types::ValueState, StorageError> {
todo!()
let statement = values::get_by_flag(raw_label, flag);
let mut conn = self.get_connection().await?;
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|e| StorageError::Other(format!("Failed to execute query: {e}")))?;
let row = query_stream
.into_row()
.await
.map_err(|e| StorageError::Other(format!("Failed to fetch first result row: {e}")))?;
if let Some(row) = row {
statement.parse(&row)
} else {
Err(StorageError::NotFound(format!(
"ValueState for label {:?} not found",
raw_label
)))
}
}
// Note: user and username here is the raw_label. The assumption is this is a single key for a single user, but that's
// too restrictive for what we want, so generalize the name a bit.
async fn get_user_state_versions(
&self,
usernames: &[AkdLabel],
raw_labels: &[AkdLabel],
flag: types::ValueStateRetrievalFlag,
) -> Result<HashMap<AkdLabel, (u64, AkdValue)>, StorageError> {
todo!()
let mut conn = self.get_connection().await?;
let temp_table = TempTable::RawLabelSearch;
let create_temp_table = temp_table.create();
let temp_table_name = &temp_table.to_string();
// Begin a transaction
conn.simple_query("BEGIN TRANSACTION")
.await
.map_err(|e| StorageError::Transaction(format!("Failed to begin transaction: {e}")))?;
let result = async {
conn.simple_query(&create_temp_table)
.await
.map_err(|e| StorageError::Other(format!("Failed to create temp table: {e}")))?;
// Use bulk_insert to insert all the raw_labels into a temporary table
let mut bulk = conn
.bulk_insert(&temp_table_name)
.await
.map_err(|e| StorageError::Other(format!("Failed to start bulk insert: {e}")))?;
for raw_label in raw_labels {
let row = (raw_label.0.clone()).into_row();
bulk.send(row).await.map_err(|e| {
StorageError::Other(format!("Failed to add row to bulk insert: {e}"))
})?;
}
bulk.finalize()
.await
.map_err(|e| StorageError::Other(format!("Failed to finalize bulk insert: {e}")))?;
// read rows matching the raw_labels from the temporary table
let statement = values::get_versions_by_flag(&temp_table_name, flag);
let query_stream = conn
.query(statement.sql(), &statement.params())
.await
.map_err(|e| {
StorageError::Other(format!("Failed to execute batch get query: {e}"))
})?;
let mut results = HashMap::new();
let rows = query_stream
.into_first_result()
.await
.map_err(|e| StorageError::Other(format!("Failed to fetch rows: {e}")))?;
for row in rows {
let label_version = statement.parse(&row)?;
results.insert(
label_version.label,
(label_version.version, label_version.data),
);
}
Ok(results)
};
match result.await {
Ok(results) => {
conn.simple_query("COMMIT").await.map_err(|e| {
StorageError::Transaction(format!("Failed to commit transaction: {e}"))
})?;
Ok(results)
}
Err(e) => {
conn.simple_query("ROLLBACK").await.map_err(|e| {
StorageError::Transaction(format!("Failed to rollback transaction: {e}"))
})?;
Err(e)
}
}
}
}

View File

@@ -0,0 +1,781 @@
use akd::{
errors::StorageError,
storage::{
types::{DbRecord, StorageType, ValueState},
Storable,
},
tree_node::TreeNodeWithPreviousValue,
NodeLabel,
};
use ms_database::{ColumnData, IntoRow, Row, ToSql, TokenRow};
use crate::{migrations::{
TABLE_AZKS, TABLE_HISTORY_TREE_NODES, TABLE_VALUES
}, temp_table::TempTable};
use crate::sql_params::SqlParams;
const SELECT_AZKS_DATA: &'static [&str] = &["epoch", "num_nodes"];
const SELECT_HISTORY_TREE_NODE_DATA: &'static [&str] = &[
"label_len",
"label_val",
"last_epoch",
"least_descendant_ep",
"parent_label_len",
"parent_label_val",
"node_type",
"left_child_len",
"left_child_label_val",
"right_child_len",
"right_child_label_val",
"hash",
"p_last_epoch",
"p_least_descendant_ep",
"p_parent_label_len",
"p_parent_label_val",
"p_node_type",
"p_left_child_len",
"p_left_child_label_val",
"p_right_child_len",
"p_right_child_label_val",
"p_hash",
];
const SELECT_LABEL_DATA: &'static [&str] = &[
"raw_label",
"epoch",
"version",
"node_label_val",
"node_label_len",
"data",
];
pub(crate) struct Statement {
sql: String,
params: SqlParams,
}
impl Statement {
pub fn new(sql: String, params: SqlParams) -> Self {
Self { sql, params }
}
pub fn sql(&self) -> &str {
&self.sql
}
pub fn params(&self) -> Vec<&dyn ToSql> {
self.params.values()
}
}
pub(crate) struct QueryStatement<Out> {
sql: String,
params: SqlParams,
parser: fn(&Row) -> Result<Out, StorageError>,
}
impl<Out> QueryStatement<Out> {
pub fn new(sql: String, params: SqlParams, parser: fn(&Row) -> Result<Out, StorageError>) -> Self {
Self { sql, params, parser }
}
pub fn sql(&self) -> &str {
&self.sql
}
pub fn params(&self) -> Vec<&dyn ToSql> {
self.params.values()
}
pub fn parse(&self, row: &Row) -> Result<Out, StorageError> {
(self.parser)(row)
}
}
pub(crate) trait MsSqlStorable {
fn set_statement(&self) -> Result<Statement, StorageError>;
fn set_batch_statement(storage_type: &StorageType) -> String;
fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError>;
fn get_batch_temp_table_rows<St: Storable>(
key: &[St::StorageKey],
) -> Result<Vec<TokenRow>, StorageError>;
fn get_batch_statement<St: Storable>() -> String;
fn from_row<St: Storable>(row: &Row) -> Result<Self, StorageError>
where
Self: Sized;
fn into_row(&self) -> Result<TokenRow, StorageError>;
}
impl MsSqlStorable for DbRecord {
fn set_statement(&self) -> Result<Statement, StorageError> {
match &self {
DbRecord::Azks(azks) => {
let mut params = SqlParams::new();
params.add("key", Box::new(1u8)); // constant key
// TODO: Fixup as conversions
params.add("epoch", Box::new(azks.latest_epoch as i64));
params.add("num_nodes", Box::new(azks.num_nodes as i64));
let sql = format!(
r#"
MERGE INTO dbo.{TABLE_AZKS} AS target
USING (SELECT {}) AS source
ON target.[key] = source.[key]
WHEN MATCHED THEN
UPDATE SET {}
target.[epoch] = source.[epoch],
target.[num_nodes] = source.[num_nodes]
WHEN NOT MATCHED THEN
INSERT ({})
VALUES ({});
"#,
params.keys_as_columns().join(", "),
params.set_columns_equal("target.", "source.").join(", "),
params.columns().join(", "),
params.keys().join(", ")
);
Ok(Statement::new(sql, params))
}
DbRecord::TreeNode(node) => {
let mut params = SqlParams::new();
params.add("label_len", Box::new(node.label.label_len as i32));
params.add("label_val", Box::new(node.label.label_val.to_vec()));
// Latest node values
params.add("last_epoch", Box::new(node.latest_node.last_epoch as i64));
params.add(
"least_descendant_ep",
Box::new(node.latest_node.min_descendant_epoch as i64),
);
params.add(
"parent_label_len",
Box::new(node.latest_node.parent.label_len as i32),
);
params.add(
"parent_label_val",
Box::new(node.latest_node.parent.label_val.to_vec()),
);
params.add("node_type", Box::new(node.latest_node.node_type as i16));
params.add(
"left_child_len",
Box::new(node.latest_node.left_child.map(|lc| lc.label_len as i32)),
);
params.add(
"left_child_val",
Box::new(node.latest_node.left_child.map(|lc| lc.label_val.to_vec())),
);
params.add(
"right_child_len",
Box::new(node.latest_node.right_child.map(|rc| rc.label_len as i32)),
);
params.add(
"right_child_val",
Box::new(node.latest_node.right_child.map(|rc| rc.label_val.to_vec())),
);
params.add("[hash]", Box::new(node.latest_node.hash.0.to_vec()));
// Previous node values
params.add(
"p_last_epoch",
Box::new(node.previous_node.clone().map(|p| p.last_epoch as i64)),
);
params.add(
"p_least_descendant_ep",
Box::new(
node.previous_node
.clone()
.map(|p| p.min_descendant_epoch as i64),
),
);
params.add(
"p_parent_label_len",
Box::new(node.previous_node.clone().map(|p| p.label.label_len as i32)),
);
params.add(
"p_parent_label_val",
Box::new(
node.previous_node
.clone()
.map(|p| p.label.label_val.to_vec()),
),
);
params.add(
"p_node_type",
Box::new(node.previous_node.clone().map(|p| p.node_type as i16)),
);
params.add(
"p_left_child_len",
Box::new(
node.previous_node
.clone()
.and_then(|p| p.left_child.map(|lc| lc.label_len as i32)),
),
);
params.add(
"p_left_child_val",
Box::new(
node.previous_node
.clone()
.and_then(|p| p.left_child.map(|lc| lc.label_val.to_vec())),
),
);
params.add(
"p_right_child_len",
Box::new(
node.previous_node
.clone()
.and_then(|p| p.right_child.map(|rc| rc.label_len as i32)),
),
);
params.add(
"p_right_child_val",
Box::new(
node.previous_node
.clone()
.and_then(|p| p.right_child.map(|rc| rc.label_val.to_vec())),
),
);
params.add(
"p_hash",
Box::new(node.previous_node.clone().map(|p| p.hash.0.to_vec())),
);
let sql = format!(
r#"
MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target
USING (SELECT {}) AS source
ON target.label_len = source.label_len AND target.label_val = source.label_val
WHEN MATCHED THEN
UPDATE SET {}
WHEN NOT MATCHED THEN
INSERT ({})
VALUES ({});
"#,
params.keys_as_columns().join(", "),
params
.set_columns_equal_except(
"target.",
"source.",
vec!["label_len", "label_val"]
)
.join(", "),
params.columns().join(", "),
params.keys().join(", "),
);
Ok(Statement::new(sql, params))
}
DbRecord::ValueState(state) => {
let mut params = SqlParams::new();
params.add("raw_label", Box::new(state.get_id().0.clone()));
// TODO: Fixup as conversions
params.add("epoch", Box::new(state.epoch as i64));
params.add("[version]", Box::new(state.version as i64));
params.add("node_label_val", Box::new(state.label.label_val.to_vec()));
params.add("node_label_len", Box::new(state.label.label_len as i64));
params.add("[data]", Box::new(state.value.0.clone()));
// Note: raw_label & epoch are combined the primary key, so these are always new
let sql = format!(
r#"
INSERT INTO dbo.{TABLE_VALUES} ({})
VALUES ({});
"#,
params.columns().join(", "),
params.keys().join(", "),
);
Ok(Statement::new(sql, params))
}
}
}
fn set_batch_statement(storage_type: &StorageType) -> String {
match storage_type {
StorageType::Azks => format!(
r#"
MERGE INTO dbo.{TABLE_AZKS} AS target
USING {} AS source
ON target.[key] = source.[key]
WHEN MATCHED THEN
UPDATE SET
target.[epoch] = source.[epoch],
target.[num_nodes] = source.[num_nodes]
WHEN NOT MATCHED THEN
INSERT ([key], [epoch], [num_nodes])
VALUES (source.[key], source.[epoch], source.[num_nodes]);
"#,
TempTable::Azks.to_string()
),
StorageType::TreeNode => format!(
r#"
MERGE INTO dbo.{TABLE_HISTORY_TREE_NODES} AS target
USING {} AS source
ON target.label_len = source.label_len AND target.label_val = source.label_val
WHEN MATCHED THEN
UPDATE SET
target.last_epoch = source.last_epoch,
target.least_descendant_ep = source.least_descendant_ep,
target.parent_label_len = source.parent_label_len,
target.parent_label_val = source.parent_label_val,
target.node_type = source.node_type,
target.left_child_len = source.left_child_len,
target.left_child_label_val = source.left_child_label_val,
target.right_child_len = source.right_child_len,
target.right_child_label_val = source.right_child_label_val,
target.hash = source.hash,
target.p_last_epoch = source.p_last_epoch,
target.p_least_descendant_ep = source.p_least_descendant_ep,
target.p_parent_label_len = source.p_parent_label_len,
target.p_parent_label_val = source.p_parent_label_val,
target.p_node_type = source.p_node_type,
target.p_left_child_len = source.p_left_child_len,
target.p_left_child_label_val = source.p_left_child_label_val,
target.p_right_child_len = source.p_right_child_len,
target.p_right_child_label_val = source.p_right_child_label_val,
target.p_hash = source.p_hash
WHEN NOT MATCHED THEN
INSERT (
label_len
, label_val
, last_epoch
, least_descendant_ep
, parent_label_len
, parent_label_val
, node_type
, left_child_len
, left_child_label_val
, right_child_len
, right_child_label_val
, hash
, p_last_epoch
, p_least_descendant_ep
, p_parent_label_len
, p_parent_label_val
, p_node_type
, p_left_child_len
, p_left_child_label_val
, p_right_child_len
, p_right_child_label_val
, p_hash
)
VALUES (
source.label_len
, source.label_val
, source.last_epoch
, source.least_descendant_ep
, source.parent_label_len
, source.parent_label_val
, source.node_type
, source.left_child_len
, source.left_child_label_val
, source.right_child_len
, source.right_child_label_val
, source.hash
, source.p_last_epoch
, source.p_least_descendant_ep
, source.p_parent_label_len
, source.p_parent_label_val
, source.p_node_type
, source.p_left_child_len
, source.p_left_child_label_val
, source.p_right_child_len
, source.p_right_child_label_val
, source.p_hash
);
"#,
TempTable::HistoryTreeNodes.to_string()
),
StorageType::ValueState => format!(
r#"
MERGE INTO dbo.{TABLE_VALUES} AS target
USING {} AS source
ON target.raw_label = source.raw_label AND target.epoch = source.epoch
WHEN MATCHED THEN
UPDATE SET
target.[version] = source.[version],
target.node_label_val = source.node_label_val,
target.node_label_len = source.node_label_len,
target.[data] = source.[data]
WHEN NOT MATCHED THEN
INSERT (raw_label, epoch, [version], node_label_val, node_label_len, [data])
VALUES (source.raw_label, source.epoch, source.[version], source.node_label_val, source.node_label_len, source.[data]);
"#,
TempTable::Values.to_string()
),
}
}
fn get_statement<St: Storable>(key: &St::StorageKey) -> Result<Statement, StorageError> {
let mut params = SqlParams::new();
let sql = match St::data_type() {
StorageType::Azks => {
format!(
r#"
SELECT {} from dbo.{} LIMIT 1;
"#,
SELECT_AZKS_DATA.join(", "),
TABLE_AZKS
)
}
StorageType::TreeNode => {
let bin = St::get_full_binary_key_id(key);
// These are constructed from a safe key, they should never fail
let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
.expect("Failed to decode key"); // TODO: should this be an error?
params.add("label_len", Box::new(key.0.label_len as i32));
params.add("label_val", Box::new(key.0.label_val.to_vec()));
format!(
r#"
SELECT {} from dbo.{TABLE_HISTORY_TREE_NODES} WHERE [label_len] = {} AND [label_val] = {};
"#,
SELECT_HISTORY_TREE_NODE_DATA.join(", "),
params.key_for("label_len").expect("key present"),
params.key_for("label_val").expect("key present"),
)
}
StorageType::ValueState => {
let bin = St::get_full_binary_key_id(key);
// These are constructed from a safe key, they should never fail
let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?
params.add("raw_label", Box::new(key.0.clone()));
params.add("epoch", Box::new(key.1 as i64));
format!(
r#"
SELECT {} from dbo.{TABLE_VALUES} WHERE [raw_label] = {} AND [epoch] = {};
"#,
SELECT_LABEL_DATA.join(", "),
params.key_for("raw_label").expect("key present"),
params.key_for("epoch").expect("key present"),
)
}
};
Ok(Statement::new(sql, params))
}
fn get_batch_temp_table_rows<St: Storable>(
key: &[St::StorageKey],
) -> Result<Vec<TokenRow>, StorageError> {
match St::data_type() {
StorageType::Azks => Err(StorageError::Other(
"Batch temp table rows not supported for Azks".to_string(),
)),
StorageType::TreeNode => {
let mut rows = Vec::new();
for k in key {
let bin = St::get_full_binary_key_id(k);
// These are constructed from a safe key, they should never fail
let key = TreeNodeWithPreviousValue::key_from_full_binary(&bin)
.expect("Failed to decode key");
let row = (key.0.label_len as i32, key.0.label_val.to_vec()).into_row();
rows.push(row);
}
Ok(rows)
}
StorageType::ValueState => {
let mut rows = Vec::new();
for k in key {
let bin = St::get_full_binary_key_id(k);
// These are constructed from a safe key, they should never fail
let key = ValueState::key_from_full_binary(&bin).expect("Failed to decode key"); // TODO: should this be an error?
let row = (key.0.clone(), key.1 as i64).into_row();
rows.push(row);
}
Ok(rows)
}
}
}
fn get_batch_statement<St: Storable>() -> String {
// Note: any changes to these columns need to be reflected in from_row below
match St::data_type() {
StorageType::Azks => panic!("Batch get not supported for Azks"),
StorageType::TreeNode => format!(
r#"
SELECT {}
FROM dbo.{TABLE_HISTORY_TREE_NODES} h
INNER JOIN {} t
ON h.label_len = t.label_len
AND h.label_val = t.label_val;
"#,
SELECT_HISTORY_TREE_NODE_DATA
.iter()
.map(|s| format!("h.{s}"))
.collect::<Vec<_>>()
.join(", "),
TempTable::for_ids::<St>().to_string()
),
StorageType::ValueState => format!(
r#"
SELECT {}
FROM dbo.{TABLE_VALUES} v
INNER JOIN {} t
ON v.raw_label = t.raw_label
AND v.epoch = t.epoch;
"#,
SELECT_LABEL_DATA
.iter()
.map(|s| format!("v.{s}"))
.collect::<Vec<_>>()
.join(", "),
TempTable::for_ids::<St>().to_string()
),
}
}
fn from_row<St: Storable>(row: &Row) -> Result<Self, StorageError>
where
Self: Sized,
{
match St::data_type() {
// TODO: check this
StorageType::Azks => {
let epoch: i64 = row
.get("epoch")
.ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?;
let num_nodes: i64 = row.get("num_nodes").ok_or_else(|| {
StorageError::Other("num_nodes is NULL or missing".to_string())
})?;
let azks = DbRecord::build_azks(epoch as u64, num_nodes as u64);
Ok(DbRecord::Azks(azks))
}
StorageType::TreeNode => {
let label_len: i32 = row.get("label_len").ok_or_else(|| {
StorageError::Other("label_len is NULL or missing".to_string())
})?;
let label_val: &[u8] = row.get("label_val").ok_or_else(|| {
StorageError::Other("label_val is NULL or missing".to_string())
})?;
let last_epoch: i64 = row.get("last_epoch").ok_or_else(|| {
StorageError::Other("last_epoch is NULL or missing".to_string())
})?;
let least_descendant_ep: i64 = row.get("least_descendant_ep").ok_or_else(|| {
StorageError::Other("least_descendant_ep is NULL or missing".to_string())
})?;
let parent_label_len: i32 = row.get("parent_label_len").ok_or_else(|| {
StorageError::Other("parent_label_len is NULL or missing".to_string())
})?;
let parent_label_val: &[u8] = row.get("parent_label_val").ok_or_else(|| {
StorageError::Other("parent_label_val is NULL or missing".to_string())
})?;
let node_type: i16 = row.get("node_type").ok_or_else(|| {
StorageError::Other("node_type is NULL or missing".to_string())
})?;
let left_child_len: Option<i32> = row.get("left_child_len");
let left_child_label_val: Option<&[u8]> = row.get("left_child_label_val");
let right_child_len: Option<i32> = row.get("right_child_len");
let right_child_label_val: Option<&[u8]> = row.get("right_child_len");
let hash: &[u8] = row
.get("hash")
.ok_or_else(|| StorageError::Other("hash is NULL or missing".to_string()))?;
let p_last_epoch: Option<i64> = row.get("p_last_epoch");
let p_least_descendant_ep: Option<i64> = row.get("p_least_descendant_ep");
let p_parent_label_len: Option<i32> = row.get("p_parent_label_len");
let p_parent_label_val: Option<&[u8]> = row.get("p_parent_label_val");
let p_node_type: Option<i16> = row.get("p_node_type");
let p_left_child_len: Option<i32> = row.get("p_left_child_len");
let p_left_child_label_val: Option<&[u8]> = row.get("p_left_child_label_val");
let p_right_child_len: Option<i32> = row.get("p_right_child_len");
let p_right_child_label_val: Option<&[u8]> = row.get("p_right_child_label_val");
let p_hash: Option<&[u8]> = row.get("p_hash");
// Make child nodes
fn optional_child_label(
child_val: Option<&[u8]>,
child_len: Option<i32>,
) -> Result<Option<NodeLabel>, StorageError> {
match (child_val, child_len) {
(Some(val), Some(len)) => {
let val_vec: Vec<u8> = val.to_vec().try_into().map_err(|_| {
StorageError::Other("child_val has incorrect length".to_string())
})?;
Ok(Some(NodeLabel::new(
val_vec.try_into().map_err(|_| {
StorageError::Other(
"child_val has incorrect length".to_string(),
)
})?,
len as u32,
)))
}
_ => Ok(None),
}
}
let left_child = optional_child_label(left_child_label_val, left_child_len)?;
let right_child = optional_child_label(right_child_label_val, right_child_len)?;
let p_left_child = optional_child_label(p_left_child_label_val, p_left_child_len)?;
let p_right_child =
optional_child_label(p_right_child_label_val, p_right_child_len)?;
let massaged_p_parent_label_val: Option<[u8; 32]> = match p_parent_label_val {
Some(v) => Some(v.to_vec().try_into().map_err(|_| {
StorageError::Other("p_parent_label_val has incorrect length".to_string())
})?),
None => None,
};
let massaged_hash: akd::Digest = akd::hash::try_parse_digest(&hash.to_vec())
.map_err(|_| StorageError::Other("hash has incorrect length".to_string()))?;
let massaged_p_hash: Option<akd::Digest> = match p_hash {
Some(v) => Some(akd::hash::try_parse_digest(&v.to_vec()).map_err(|_| {
StorageError::Other("p_hash has incorrect length".to_string())
})?),
None => None,
};
let node = DbRecord::build_tree_node_with_previous_value(
label_val.try_into().map_err(|_| {
StorageError::Other("label_val has incorrect length".to_string())
})?,
label_len as u32,
last_epoch as u64,
least_descendant_ep as u64,
parent_label_val.try_into().map_err(|_| {
StorageError::Other("parent_label_val has incorrect length".to_string())
})?,
parent_label_len as u32,
node_type as u8,
left_child,
right_child,
massaged_hash,
p_last_epoch.map(|v| v as u64),
p_least_descendant_ep.map(|v| v as u64),
massaged_p_parent_label_val,
p_parent_label_len.map(|v| v as u32),
p_node_type.map(|v| v as u8),
p_left_child,
p_right_child,
massaged_p_hash,
);
Ok(DbRecord::TreeNode(node))
}
StorageType::ValueState => Ok(DbRecord::ValueState(crate::tables::values::from_row(row)?)),
}
}
fn into_row(&self) -> Result<TokenRow, StorageError> {
match &self {
DbRecord::Azks(azks) => {
let row = (
1u8, // constant key
azks.latest_epoch as i64,
azks.num_nodes as i64,
)
.into_row();
Ok(row)
}
DbRecord::TreeNode(node) => {
let mut row = TokenRow::new();
row.push(ColumnData::I32(Some(node.label.label_len as i32)));
row.push(ColumnData::Binary(Some(
node.label.label_val.to_vec().into(),
)));
// Latest node values
row.push(ColumnData::I64(Some(node.latest_node.last_epoch as i64)));
row.push(ColumnData::I64(Some(
node.latest_node.min_descendant_epoch as i64,
)));
row.push(ColumnData::I32(Some(
node.latest_node.parent.label_len as i32,
)));
row.push(ColumnData::Binary(Some(
node.latest_node.parent.label_val.to_vec().into(),
)));
row.push(ColumnData::I16(Some(node.latest_node.node_type as i16)));
match &node.latest_node.left_child {
Some(lc) => {
row.push(ColumnData::I32(Some(lc.label_len as i32)));
row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
}
None => {
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
}
}
match &node.latest_node.right_child {
Some(rc) => {
row.push(ColumnData::I32(Some(rc.label_len as i32)));
row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
}
None => {
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
}
}
row.push(ColumnData::Binary(Some(
node.latest_node.hash.0.to_vec().into(),
)));
// Previous node values
match &node.previous_node {
Some(p) => {
row.push(ColumnData::I64(Some(p.last_epoch as i64)));
row.push(ColumnData::I64(Some(p.min_descendant_epoch as i64)));
row.push(ColumnData::I32(Some(p.label.label_len as i32)));
row.push(ColumnData::Binary(Some(p.label.label_val.to_vec().into())));
row.push(ColumnData::I16(Some(p.node_type as i16)));
match &p.left_child {
Some(lc) => {
row.push(ColumnData::I32(Some(lc.label_len as i32)));
row.push(ColumnData::Binary(Some(lc.label_val.to_vec().into())));
}
None => {
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
}
}
match &p.right_child {
Some(rc) => {
row.push(ColumnData::I32(Some(rc.label_len as i32)));
row.push(ColumnData::Binary(Some(rc.label_val.to_vec().into())));
}
None => {
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
}
}
row.push(ColumnData::Binary(Some(p.hash.0.to_vec().into())));
}
None => {
// Node Values
row.push(ColumnData::I64(None));
row.push(ColumnData::I64(None));
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
row.push(ColumnData::I16(None));
// Left child
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
// Right child
row.push(ColumnData::I32(None));
row.push(ColumnData::Binary(None));
// Hash
row.push(ColumnData::Binary(None));
}
}
Ok(row)
}
DbRecord::ValueState(state) => {
let row = (
state.get_id().0.clone(),
state.epoch as i64,
state.version as i64,
state.label.label_val.to_vec(),
state.label.label_len as i64,
state.value.0.clone(),
)
.into_row();
Ok(row)
}
}
}
}

View File

@@ -0,0 +1,177 @@
use ms_database::ToSql;
pub struct SqlParam {
/// The parameter key (e.g., "@P1", "@P2")
pub key: String,
/// The column name this parameter maps to
column: String,
pub data: Box<dyn ToSql>,
}
impl SqlParam {
fn column(&self) -> String {
SqlParam::wrap_in_brackets(&self.column)
}
fn wrap_in_brackets(s: &str) -> String {
// ensure column names are wrapped in brackets for SQL Server
let trimmed = s.trim();
let starts_with_bracket = trimmed.starts_with('[');
let ends_with_bracket = trimmed.ends_with(']');
match (starts_with_bracket, ends_with_bracket) {
(true, true) => trimmed.to_string(),
(true, false) => format!("{}]", trimmed),
(false, true) => format!("[{}", trimmed),
(false, false) => format!("[{}]", trimmed),
}
}
}
pub(crate) struct SqlParams {
params: Vec<Box<SqlParam>>,
}
impl SqlParams {
pub fn new() -> Self {
Self { params: Vec::new() }
}
pub fn add(&mut self, column: impl Into<String>, value: Box<dyn ToSql>) {
self.params.push(Box::new(SqlParam {
key: format!("@P{}", self.params.len() + 1),
column: column.into(),
data: value,
}));
}
pub fn keys(&self) -> Vec<String> {
self.params.iter().map(|p| p.key.clone()).collect()
}
pub fn keys_as_columns(&self) -> Vec<String> {
self.params
.iter()
.map(|p| format!("{} AS {}", p.key, p.column))
.collect()
}
pub fn keys_except_columns(&self, excludes: Vec<&str>) -> Vec<String> {
self.params
.iter()
.filter(|p| !excludes.contains(&p.column.as_str()))
.map(|p| p.key.clone())
.collect()
}
pub fn key_for(&self, column: &str) -> Option<String> {
self.params
.iter()
.find(|p| p.column == column)
.map(|p| p.key.clone())
}
pub fn columns(&self) -> Vec<String> {
self.params
.iter()
.map(|p| p.column())
.collect()
}
pub fn columns_except(&self, excludes: Vec<&str>) -> Vec<String> {
self.params
.iter()
.filter(|p| !excludes.contains(&p.column.as_str()))
.map(|p| p.column())
.collect()
}
pub fn columns_prefix_with(&self, prefix: &str) -> Vec<String> {
self.params
.iter()
.map(|p| format!("{}{}", prefix, p.column()))
.collect()
}
pub fn set_columns_equal(&self, assign_prefix: &str, source_prefix: &str) -> Vec<String> {
self.params
.iter()
.map(|p| format!("{}{} = {}{}", assign_prefix, p.column(), source_prefix, p.column()))
.collect()
}
pub fn set_columns_equal_except(
&self,
assign_prefix: &str,
source_prefix: &str,
excludes: Vec<&str>,
) -> Vec<String> {
self.params
.iter()
.filter(|p| !excludes.contains(&p.column.as_str()))
.map(|p| format!("{}{} = {}{}", assign_prefix, p.column(), source_prefix, p.column()))
.collect()
}
pub fn values(&self) -> Vec<&(dyn ToSql)> {
self.params
.iter()
.map(|b| b.data.as_ref() as &(dyn ToSql))
.collect()
}
}
pub struct VecStringBuilder<'a> {
strings: Vec<String>,
ops: Vec<StringBuilderOperation<'a>>,
}
enum StringBuilderOperation<'a> {
StringOperation(Box<dyn Fn(String) -> String + 'a>),
VectorOperation(Box<dyn Fn(Vec<String>) -> Vec<String> + 'a>),
}
impl<'a> VecStringBuilder<'a> {
pub fn new(strings: Vec<String>) -> Self {
Self {
strings,
ops: Vec::new(),
}
}
pub fn map<F>(&mut self, op: F)
where
F: Fn(String) -> String + 'static,
{
self.ops.push(StringBuilderOperation::StringOperation(Box::new(op)));
}
pub fn build(self) -> Vec<String> {
let mut result = self.strings;
for op in self.ops {
match op {
StringBuilderOperation::StringOperation(f) => {
result = result.into_iter().map(f.as_ref()).collect();
}
StringBuilderOperation::VectorOperation(f) => {
result = f(result);
}
}
}
result
}
pub fn join(self, sep: &str) -> String {
self.build().join(sep)
}
pub fn except(&mut self, excludes: Vec<&'a str>) {
self.ops.push(StringBuilderOperation::VectorOperation(Box::new(
move |vec: Vec<String>| {
vec.into_iter()
.filter(|s| !excludes.contains(&s.as_str()))
.collect()
},
)));
}
}

View File

@@ -0,0 +1 @@
pub(crate) mod values;

View File

@@ -0,0 +1,190 @@
use akd::{
errors::StorageError,
storage::types::{DbRecord, ValueState, ValueStateRetrievalFlag},
AkdLabel, AkdValue,
};
use ms_database::Row;
use crate::{migrations::TABLE_VALUES, ms_sql_storable::{QueryStatement, Statement}, sql_params::SqlParams};
pub fn get_all(raw_label: &AkdLabel) -> QueryStatement<ValueState> {
let mut params = SqlParams::new();
// the raw vector is the key for value storage
params.add("raw_label", Box::new(raw_label.0.clone()));
let sql = format!(
r#"
SELECT raw_label, epoch, version, node_label_val, node_label_len, data
FROM {}
WHERE raw_label = {}
"#,
TABLE_VALUES, params
.key_for("raw_label")
.expect("raw_label was added to the params list")
);
QueryStatement::new(sql, params, from_row)
}
pub fn get_by_flag(raw_label: &AkdLabel, flag: ValueStateRetrievalFlag) -> QueryStatement<ValueState> {
let mut params = SqlParams::new();
params.add("raw_label", Box::new(raw_label.0.clone()));
match flag {
ValueStateRetrievalFlag::SpecificEpoch(epoch)
| ValueStateRetrievalFlag::LeqEpoch(epoch) => params.add("epoch", Box::new(epoch as i64)),
ValueStateRetrievalFlag::SpecificVersion(version) => {
params.add("version", Box::new(version as i64))
}
_ => {}
}
let mut sql = format!(
r#"
SELECT TOP(1) raw_label, epoch, version, node_label_val, node_label_len, data
FROM {}
WHERE raw_label = {}
"#,
TABLE_VALUES,
params
.key_for("raw_label")
.expect("raw_label was added to the params list")
);
match flag {
ValueStateRetrievalFlag::SpecificEpoch(_) => {
sql.push_str(&format!(
" AND epoch = {}",
&params
.key_for("epoch")
.expect("epoch was added to the params list")
));
}
ValueStateRetrievalFlag::SpecificVersion(_) => {
sql.push_str(&format!(
" AND version = {}",
&params
.key_for("version")
.expect("version was added to the params list")
));
}
ValueStateRetrievalFlag::MaxEpoch => {
sql.push_str(" ORDER BY epoch DESC ");
}
ValueStateRetrievalFlag::MinEpoch => {
sql.push_str(" ORDER BY epoch ASC ");
}
ValueStateRetrievalFlag::LeqEpoch(_) => {
sql.push_str(&format!(
r#"
AND epoch <= {}
ORDER BY epoch DESC
"#,
&params
.key_for("epoch")
.expect("epoch was added to the params list")
));
}
}
QueryStatement::new(sql, params, from_row)
}
pub fn get_versions_by_flag(
temp_table_name: &str,
flag: ValueStateRetrievalFlag,
) -> QueryStatement<LabelVersion> {
let mut params = SqlParams::new();
let (filter, epoch_col) = match flag {
ValueStateRetrievalFlag::SpecificVersion(version) => {
params.add("version", Box::new(version as i64));
(format!("WHERE tmp.version = {}", params.key_for("version").expect("version was added to the params list")), "tmp.epoch")
}
ValueStateRetrievalFlag::SpecificEpoch(epoch) => {
params.add("epoch", Box::new(epoch as i64));
(format!("WHERE tmp.epoch = {}", params.key_for("epoch").expect("epoch was added to the params list")), "tmp.epoch")
}
ValueStateRetrievalFlag::LeqEpoch(epoch) => {
params.add("epoch", Box::new(epoch as i64));
(format!("WHERE tmp.epoch <= {}", params.key_for("epoch").expect("epoch was added to the params list")), "MAX(tmp.epoch)")
}
ValueStateRetrievalFlag::MaxEpoch => ("".to_string(), "MAX(tmp.epoch)"),
ValueStateRetrievalFlag::MinEpoch => ("".to_string(), "MIN(tmp.epoch)"),
};
let sql = format!(
r#"
SELECT t.raw_label, t.version, t.data
FROM {TABLE_VALUES} t
INNER JOIN (
SELECT tmp.raw_label as raw_label, {} as epoch
FROM {TABLE_VALUES} tmp
INNER JOIN {temp_table_name} s ON s.raw_label = tmp.raw_label
{}
GROUP BY tmp.raw_label
) epochs on epochs.raw_label = t.raw_label AND epochs.epoch = t.epoch
"#,
epoch_col,
filter,
);
QueryStatement::new(sql, params, version_from_row)
}
pub(crate) fn from_row(row: &Row) -> Result<ValueState, StorageError> {
let raw_label: &[u8] = row
.get("raw_label")
.ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?;
let epoch: i64 = row
.get("epoch")
.ok_or_else(|| StorageError::Other("epoch is NULL or missing".to_string()))?;
let version: i64 = row
.get("version")
.ok_or_else(|| StorageError::Other("version is NULL or missing".to_string()))?;
let node_label_val: &[u8] = row
.get("node_label_val")
.ok_or_else(|| StorageError::Other("node_label_val is NULL or missing".to_string()))?;
let node_label_len: i64 = row
.get("node_label_len")
.ok_or_else(|| StorageError::Other("node_label_len is NULL or missing".to_string()))?;
let data: &[u8] = row
.get("data")
.ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?;
let state = DbRecord::build_user_state(
raw_label.to_vec(),
data.to_vec(),
version as u64,
node_label_len as u32,
node_label_val
.to_vec()
.try_into()
.map_err(|_| StorageError::Other("node_label_val has incorrect length".to_string()))?,
epoch as u64,
);
Ok(state)
}
pub(crate) struct LabelVersion {
pub label: AkdLabel,
pub version: u64,
pub data: AkdValue,
}
fn version_from_row(row: &Row) -> Result<LabelVersion, StorageError> {
let raw_label: &[u8] = row
.get("raw_label")
.ok_or_else(|| StorageError::Other("raw_label is NULL or missing".to_string()))?;
let version: i64 = row
.get("version")
.ok_or_else(|| StorageError::Other("version is NULL or missing".to_string()))?;
let data: &[u8] = row
.get("data")
.ok_or_else(|| StorageError::Other("data is NULL or missing".to_string()))?;
Ok(LabelVersion {
label: AkdLabel(raw_label.to_vec()),
version: version as u64,
data: AkdValue(data.to_vec()),
})
}

View File

@@ -7,13 +7,10 @@ pub(crate) enum TempTable {
Azks,
HistoryTreeNodes,
Values,
RawLabelSearch,
}
impl TempTable {
pub fn for_ids_for(storage_type: &StorageType) -> Self {
TempTable::Ids(storage_type.clone())
}
pub fn for_ids<St: Storable>() -> Self {
TempTable::Ids(St::data_type())
}
@@ -111,6 +108,15 @@ impl TempTable {
TEMP_IDS_TABLE
),
}
TempTable::RawLabelSearch => format!(
r#"
CREATE TABLE {} (
raw_label VARBINARY(256) NOT NULL,
PRIMARY KEY (raw_label)
);
"#,
TEMP_SEARCH_LABELS_TABLE
)
}
}
}
@@ -122,6 +128,7 @@ impl ToString for TempTable {
TempTable::Azks => TEMP_AZKS_TABLE.to_string(),
TempTable::HistoryTreeNodes => TEMP_HISTORY_TREE_NODES_TABLE.to_string(),
TempTable::Values => TEMP_VALUES_TABLE.to_string(),
TempTable::RawLabelSearch => TEMP_SEARCH_LABELS_TABLE.to_string(),
}
}
}
@@ -137,6 +144,7 @@ impl From<StorageType> for TempTable {
}
pub(crate) const TEMP_IDS_TABLE: &str = "#akd_temp_ids";
pub(crate) const TEMP_SEARCH_LABELS_TABLE: &str = "#akd_temp_search_labels";
pub(crate) const TEMP_AZKS_TABLE: &str = "#akd_temp_azks";
pub(crate) const TEMP_HISTORY_TREE_NODES_TABLE: &str = "#akd_temp_history_tree_nodes";
pub(crate) const TEMP_VALUES_TABLE: &str = "#akd_temp_values";

View File

@@ -6,7 +6,7 @@ pub use pool::ConnectionManager as MsSqlConnectionManager;
pub use pool::{OnConnectError};
// re-expose tiberius types for convenience
pub use tiberius::{error::Error as MsDbError, Column, Row, ToSql};
pub use tiberius::{error::Error as MsDbError, Column, Row, FromSql, ToSql, IntoRow, TokenRow, ColumnData};
// re-expose bb8 types for convenience
pub type Pool = bb8::Pool<MsSqlConnectionManager>;