1
0
mirror of https://github.com/bitwarden/server synced 2026-01-29 07:43:22 +00:00

Use connection strings and implement connection errors

This commit is contained in:
Matt Gibson
2025-10-02 10:00:54 -07:00
parent e06807df11
commit 3965b241fa

View File

@@ -5,7 +5,17 @@ use tokio_util::compat::TokioAsyncWriteCompatExt;
pub use migrate::{Migration, MigrationError, run_pending_migrations};
use bb8::ManageConnection;
use tiberius::{AuthMethod, Client, Config};
use tiberius::{Client, Config};
#[derive(thiserror::Error, Debug)]
pub enum OnConnectError {
#[error("Config error: {0}")]
Config(#[source] tiberius::error::Error),
#[error("TCP error: {0}")]
Tcp(#[from] std::io::Error),
#[error("On Connect error: {0}")]
OnConnect(#[source] tiberius::error::Error),
}
pub struct ConnectionManager {
connection_string: String,
@@ -16,13 +26,8 @@ impl ConnectionManager {
Self { connection_string }
}
pub async fn connect(&self) -> Result<ManagedConnection, tiberius::error::Error> {
let mut config = Config::new();
config.host("localhost");
config.port(1433);
config.authentication(AuthMethod::sql_server("SA", "<YourStrong@Passw0rd>"));
config.trust_cert(); // on production, it is not a good idea to do this
pub async fn connect(&self) -> Result<ManagedConnection, OnConnectError> {
let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?;
let tcp = TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
@@ -30,7 +35,7 @@ impl ConnectionManager {
// To be able to use Tokio's tcp, we're using the `compat_write` from
// the `TokioAsyncWriteCompatExt` to get a stream compatible with the
// traits from the `futures` crate.
let client = Client::connect(config, tcp.compat_write()).await?;
let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?;
Ok(ManagedConnection(client))
}
@@ -41,7 +46,7 @@ pub struct ManagedConnection(Client<Stream>);
// Transparently forward methods to the inner Client
impl ManagedConnection {
async fn execute(
pub async fn execute(
&mut self,
sql: &str,
params: &[&(dyn tiberius::ToSql)],
@@ -49,7 +54,7 @@ impl ManagedConnection {
self.0.execute(sql, params).await
}
async fn query<'a>(
pub async fn query<'a>(
&'a mut self,
sql: &str,
params: &[&(dyn tiberius::ToSql)],
@@ -57,14 +62,14 @@ impl ManagedConnection {
self.0.query(sql, params).await
}
async fn simple_query<'a>(
pub async fn simple_query<'a>(
&'a mut self,
sql: &str,
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
self.0.simple_query(sql).await
}
async fn bulk_insert<'a>(
pub async fn bulk_insert<'a>(
&'a mut self,
table: &'a str,
) -> Result<tiberius::BulkLoadRequest<'a, Stream>, tiberius::error::Error> {
@@ -78,22 +83,30 @@ impl ManagedConnection {
}
}
#[derive(thiserror::Error, Debug)]
pub enum PoolError {
#[error("Connection error: {0}")]
Connection(#[from] tiberius::error::Error),
#[error("On Connect error: {0}")]
OnConnect(#[source] OnConnectError),
#[error("Unexpected ping response: {0}")]
Ping(u8)
}
impl ManageConnection for ConnectionManager {
type Connection = ManagedConnection;
type Error = tiberius::error::Error;
type Error = PoolError;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
self.connect().await
self.connect().await.map_err(PoolError::OnConnect)
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
match conn.ping().await {
Ok(v) if v == 1 => Ok(()),
Ok(_) => Err(tiberius::error::Error::Protocol(
"Unexpected ping response".into(),
)),
Err(e) => Err(e),
Ok(v) => Err(PoolError::Ping(v)),
Err(e) => Err(PoolError::Connection(e)),
}
}