diff --git a/akd/crates/ms_database/src/lib.rs b/akd/crates/ms_database/src/lib.rs index 18af284567..e399ef90e9 100644 --- a/akd/crates/ms_database/src/lib.rs +++ b/akd/crates/ms_database/src/lib.rs @@ -5,7 +5,17 @@ use tokio_util::compat::TokioAsyncWriteCompatExt; pub use migrate::{Migration, MigrationError, run_pending_migrations}; use bb8::ManageConnection; -use tiberius::{AuthMethod, Client, Config}; +use tiberius::{Client, Config}; + +#[derive(thiserror::Error, Debug)] +pub enum OnConnectError { + #[error("Config error: {0}")] + Config(#[source] tiberius::error::Error), + #[error("TCP error: {0}")] + Tcp(#[from] std::io::Error), + #[error("On Connect error: {0}")] + OnConnect(#[source] tiberius::error::Error), +} pub struct ConnectionManager { connection_string: String, @@ -16,13 +26,8 @@ impl ConnectionManager { Self { connection_string } } - pub async fn connect(&self) -> Result { - let mut config = Config::new(); - - config.host("localhost"); - config.port(1433); - config.authentication(AuthMethod::sql_server("SA", "")); - config.trust_cert(); // on production, it is not a good idea to do this + pub async fn connect(&self) -> Result { + let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?; let tcp = TcpStream::connect(config.get_addr()).await?; tcp.set_nodelay(true)?; @@ -30,7 +35,7 @@ impl ConnectionManager { // To be able to use Tokio's tcp, we're using the `compat_write` from // the `TokioAsyncWriteCompatExt` to get a stream compatible with the // traits from the `futures` crate. - let client = Client::connect(config, tcp.compat_write()).await?; + let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?; Ok(ManagedConnection(client)) } @@ -41,7 +46,7 @@ pub struct ManagedConnection(Client); // Transparently forward methods to the inner Client impl ManagedConnection { - async fn execute( + pub async fn execute( &mut self, sql: &str, params: &[&(dyn tiberius::ToSql)], @@ -49,7 +54,7 @@ impl ManagedConnection { self.0.execute(sql, params).await } - async fn query<'a>( + pub async fn query<'a>( &'a mut self, sql: &str, params: &[&(dyn tiberius::ToSql)], @@ -57,14 +62,14 @@ impl ManagedConnection { self.0.query(sql, params).await } - async fn simple_query<'a>( + pub async fn simple_query<'a>( &'a mut self, sql: &str, ) -> Result, tiberius::error::Error> { self.0.simple_query(sql).await } - async fn bulk_insert<'a>( + pub async fn bulk_insert<'a>( &'a mut self, table: &'a str, ) -> Result, tiberius::error::Error> { @@ -78,22 +83,30 @@ impl ManagedConnection { } } +#[derive(thiserror::Error, Debug)] +pub enum PoolError { + #[error("Connection error: {0}")] + Connection(#[from] tiberius::error::Error), + #[error("On Connect error: {0}")] + OnConnect(#[source] OnConnectError), + #[error("Unexpected ping response: {0}")] + Ping(u8) +} + impl ManageConnection for ConnectionManager { type Connection = ManagedConnection; - type Error = tiberius::error::Error; + type Error = PoolError; async fn connect(&self) -> Result { - self.connect().await + self.connect().await.map_err(PoolError::OnConnect) } async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { match conn.ping().await { Ok(v) if v == 1 => Ok(()), - Ok(_) => Err(tiberius::error::Error::Protocol( - "Unexpected ping response".into(), - )), - Err(e) => Err(e), + Ok(v) => Err(PoolError::Ping(v)), + Err(e) => Err(PoolError::Connection(e)), } }