diff --git a/akd/crates/ms_database/src/lib.rs b/akd/crates/ms_database/src/lib.rs index e399ef90e9..ce93b8e5b1 100644 --- a/akd/crates/ms_database/src/lib.rs +++ b/akd/crates/ms_database/src/lib.rs @@ -1,118 +1,15 @@ mod migrate; - -use tokio::net::TcpStream; -use tokio_util::compat::TokioAsyncWriteCompatExt; - pub use migrate::{Migration, MigrationError, run_pending_migrations}; -use bb8::ManageConnection; -use tiberius::{Client, Config}; -#[derive(thiserror::Error, Debug)] -pub enum OnConnectError { - #[error("Config error: {0}")] - Config(#[source] tiberius::error::Error), - #[error("TCP error: {0}")] - Tcp(#[from] std::io::Error), - #[error("On Connect error: {0}")] - OnConnect(#[source] tiberius::error::Error), -} +mod pool; +pub use pool::ConnectionManager as MsSqlConnectionManager; -pub struct ConnectionManager { - connection_string: String, -} +// re-expose tiberius types for convenience +pub use tiberius::{error, Column, Row, ToSql}; -impl ConnectionManager { - pub fn new(connection_string: String) -> Self { - Self { connection_string } - } +// re-expose bb8 types for convenience +pub type Pool = bb8::Pool; +pub type PooledConnection<'a> = bb8::PooledConnection<'a, MsSqlConnectionManager>; - pub async fn connect(&self) -> Result { - let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?; - - let tcp = TcpStream::connect(config.get_addr()).await?; - tcp.set_nodelay(true)?; - - // To be able to use Tokio's tcp, we're using the `compat_write` from - // the `TokioAsyncWriteCompatExt` to get a stream compatible with the - // traits from the `futures` crate. - let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?; - - Ok(ManagedConnection(client)) - } -} - -type Stream = tokio_util::compat::Compat; -pub struct ManagedConnection(Client); - -// Transparently forward methods to the inner Client -impl ManagedConnection { - pub async fn execute( - &mut self, - sql: &str, - params: &[&(dyn tiberius::ToSql)], - ) -> Result { - self.0.execute(sql, params).await - } - - pub async fn query<'a>( - &'a mut self, - sql: &str, - params: &[&(dyn tiberius::ToSql)], - ) -> Result, tiberius::error::Error> { - self.0.query(sql, params).await - } - - pub async fn simple_query<'a>( - &'a mut self, - sql: &str, - ) -> Result, tiberius::error::Error> { - self.0.simple_query(sql).await - } - - pub async fn bulk_insert<'a>( - &'a mut self, - table: &'a str, - ) -> Result, tiberius::error::Error> { - self.0.bulk_insert(&table).await - } - - async fn ping(&mut self) -> Result { - let row = self.0.simple_query("SELECT 1").await?.into_first_result().await?; - let value = row[0].get(0).expect("value is present"); - Ok(value) - } -} - -#[derive(thiserror::Error, Debug)] -pub enum PoolError { - #[error("Connection error: {0}")] - Connection(#[from] tiberius::error::Error), - #[error("On Connect error: {0}")] - OnConnect(#[source] OnConnectError), - #[error("Unexpected ping response: {0}")] - Ping(u8) -} - -impl ManageConnection for ConnectionManager { - type Connection = ManagedConnection; - - type Error = PoolError; - - async fn connect(&self) -> Result { - self.connect().await.map_err(PoolError::OnConnect) - } - - async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { - match conn.ping().await { - Ok(v) if v == 1 => Ok(()), - Ok(v) => Err(PoolError::Ping(v)), - Err(e) => Err(PoolError::Connection(e)), - } - } - - fn has_broken(&self, _conn: &mut Self::Connection) -> bool { - // We don't have a good way to determine this sync. r2d2 (which bb8 is based on) recommends - // always returning false here and relying on `is_valid` to catch broken connections. - false - } -} +// re-expose macros for convenience +pub use macros::{load_migrations, migration}; diff --git a/akd/crates/ms_database/src/migrate.rs b/akd/crates/ms_database/src/migrate.rs index 19c6a32bdc..42719a41e7 100644 --- a/akd/crates/ms_database/src/migrate.rs +++ b/akd/crates/ms_database/src/migrate.rs @@ -1,6 +1,6 @@ use tiberius::{error}; -use crate::{ManagedConnection}; +use crate::pool::ManagedConnection; type Result = std::result::Result; diff --git a/akd/crates/ms_database/src/pool.rs b/akd/crates/ms_database/src/pool.rs new file mode 100644 index 0000000000..6e88b6735b --- /dev/null +++ b/akd/crates/ms_database/src/pool.rs @@ -0,0 +1,123 @@ +use std::sync::RwLock; + +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +use bb8::ManageConnection; +use tiberius::{Client, Config}; + + +#[derive(thiserror::Error, Debug)] +pub enum OnConnectError { + #[error("Config error: {0}")] + Config(#[source] tiberius::error::Error), + #[error("TCP error: {0}")] + Tcp(#[from] std::io::Error), + #[error("On Connect error: {0}")] + OnConnect(#[source] tiberius::error::Error), +} + +pub struct ConnectionManager { + connection_string: String, + is_healthy: RwLock, +} + +impl ConnectionManager { + pub fn new(connection_string: String) -> Self { + Self { connection_string, is_healthy: RwLock::new(true) } + } + + pub async fn connect(&self) -> Result { + let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?; + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + // To be able to use Tokio's tcp, we're using the `compat_write` from + // the `TokioAsyncWriteCompatExt` to get a stream compatible with the + // traits from the `futures` crate. + let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?; + + Ok(ManagedConnection(client)) + } + + /// Mark the pool as unhealthy. This is used to indicate that a connection should be replaced. + pub async fn set_unhealthy(&self) { + let mut healthy = self.is_healthy.write().expect("poisoned is_healthy lock"); + *healthy = false; + } +} + +type Stream = tokio_util::compat::Compat; +pub struct ManagedConnection(Client); + +// Transparently forward methods to the inner Client +impl ManagedConnection { + pub async fn execute( + &mut self, + sql: &str, + params: &[&(dyn tiberius::ToSql)], + ) -> Result { + self.0.execute(sql, params).await + } + + pub async fn query<'a>( + &'a mut self, + sql: &str, + params: &[&(dyn tiberius::ToSql)], + ) -> Result, tiberius::error::Error> { + self.0.query(sql, params).await + } + + pub async fn simple_query<'a>( + &'a mut self, + sql: &str, + ) -> Result, tiberius::error::Error> { + self.0.simple_query(sql).await + } + + pub async fn bulk_insert<'a>( + &'a mut self, + table: &'a str, + ) -> Result, tiberius::error::Error> { + self.0.bulk_insert(&table).await + } + + async fn ping(&mut self) -> Result { + let row = self.0.simple_query("SELECT 1").await?.into_first_result().await?; + let value = row[0].get(0).expect("value is present"); + Ok(value) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum PoolError { + #[error("Connection error: {0}")] + Connection(#[from] tiberius::error::Error), + #[error("On Connect error: {0}")] + OnConnect(#[source] OnConnectError), + #[error("Unexpected ping response: {0}")] + Ping(u8) +} + +impl ManageConnection for ConnectionManager { + type Connection = ManagedConnection; + + type Error = PoolError; + + async fn connect(&self) -> Result { + self.connect().await.map_err(PoolError::OnConnect) + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + match conn.ping().await { + Ok(v) if v == 1 => Ok(()), + Ok(v) => Err(PoolError::Ping(v)), + Err(e) => Err(PoolError::Connection(e)), + } + } + + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { + self.is_healthy.read().expect("poisoned is_healthy lock").clone() + } +}