1
0
mirror of https://github.com/bitwarden/server synced 2026-01-28 23:36:12 +00:00

fixup sql connection pool

This commit is contained in:
Matt Gibson
2025-10-16 15:39:55 -07:00
parent f4ec8bd4a0
commit 2c971fc77e
3 changed files with 46 additions and 10 deletions

View File

@@ -12,3 +12,8 @@ keywords = ["akd", "key transparency"]
[workspace.lints.clippy]
unused_async = "deny"
unwrap_used = "deny"
[workspace.dependencies]
tokio = { version = "1.47.1", features = ["full"] }
tracing = { version = "0.1.41" }
tracing-subscriber = {version = "0.3.19" }

View File

@@ -7,12 +7,20 @@ license-file.workspace = true
keywords.workspace = true
[dependencies]
async-trait = "0.1.89"
bb8 = "0.9.0"
macros = { path = "../macros" }
thiserror = "2.0.17"
tiberius = { version = "0.12.3", features = ["chrono", "tokio"] }
tokio = "1.47.1"
tokio = {workspace = true}
tokio-util = {version = "0.7.16", features = ["compat"] }
tracing = { workspace = true }
[target.'cfg(target_os = "macos")'.dependencies]
tiberius = { version = "0.12.3", default-features = false, features = ["chrono", "tokio", "rustls"] }
[target.'cfg(not(target_os = "macos"))'.workspace.dependencies]
tiberius = { version = "0.12.3", features = ["chrono", "tokio"] }
[lints]
workspace = true

View File

@@ -5,6 +5,7 @@ use tokio_util::compat::TokioAsyncWriteCompatExt;
use bb8::ManageConnection;
use tiberius::{Client, Config};
use tracing::{info, instrument, trace};
#[derive(thiserror::Error, Debug)]
pub enum OnConnectError {
@@ -23,19 +24,28 @@ pub struct ConnectionManager {
impl ConnectionManager {
pub fn new(connection_string: String) -> Self {
Self { connection_string, is_healthy: RwLock::new(true) }
Self {
connection_string,
is_healthy: RwLock::new(true),
}
}
#[instrument(skip(self), level = "info")]
pub async fn connect(&self) -> Result<ManagedConnection, OnConnectError> {
let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?;
let config =
Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?;
info!(config = ?config, "Connecting");
let tcp = TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
// To be able to use Tokio's tcp, we're using the `compat_write` from
// the `TokioAsyncWriteCompatExt` to get a stream compatible with the
// traits from the `futures` crate.
let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?;
let client = Client::connect(config, tcp.compat_write())
.await
.map_err(OnConnectError::OnConnect)?;
info!("Successfully connected");
Ok(ManagedConnection(client))
}
@@ -57,6 +67,7 @@ impl ManagedConnection {
sql: &str,
params: &[&(dyn tiberius::ToSql)],
) -> Result<tiberius::ExecuteResult, tiberius::error::Error> {
trace!(%sql, "Executing SQL");
self.0.execute(sql, params).await
}
@@ -65,6 +76,7 @@ impl ManagedConnection {
sql: &str,
params: &[&(dyn tiberius::ToSql)],
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
trace!(%sql, "Querying SQL");
self.0.query(sql, params).await
}
@@ -72,18 +84,26 @@ impl ManagedConnection {
&'a mut self,
sql: &str,
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
trace!(%sql, "Simple querying SQL");
self.0.simple_query(sql).await
}
pub async fn bulk_insert<'a>(
&'a mut self,
table: &'a str,
) -> Result<tiberius::BulkLoadRequest<'a, Stream>, tiberius::error::Error> {
trace!(%table, "Starting bulk insert");
self.0.bulk_insert(&table).await
}
async fn ping(&mut self) -> Result<u8, tiberius::error::Error> {
let row = self.0.simple_query("SELECT 1").await?.into_first_result().await?;
async fn ping(&mut self) -> Result<i32, tiberius::error::Error> {
let row = self
.0
.simple_query("SELECT 1")
.await?
.into_first_result()
.await?;
info!(?row, "Ping response");
let value = row[0].get(0).expect("value is present");
Ok(value)
}
@@ -96,7 +116,7 @@ pub enum PoolError {
#[error("On Connect error: {0}")]
OnConnect(#[source] OnConnectError),
#[error("Unexpected ping response: {0}")]
Ping(u8)
Ping(i32),
}
impl ManageConnection for ConnectionManager {
@@ -117,6 +137,9 @@ impl ManageConnection for ConnectionManager {
}
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
self.is_healthy.read().expect("poisoned is_healthy lock").clone()
self.is_healthy
.read()
.expect("poisoned is_healthy lock")
.clone()
}
}