mirror of
https://github.com/bitwarden/server
synced 2026-02-28 10:23:24 +00:00
fixup sql connection pool
This commit is contained in:
@@ -5,6 +5,7 @@ use tokio_util::compat::TokioAsyncWriteCompatExt;
|
||||
|
||||
use bb8::ManageConnection;
|
||||
use tiberius::{Client, Config};
|
||||
use tracing::{info, instrument, trace};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum OnConnectError {
|
||||
@@ -23,19 +24,28 @@ pub struct ConnectionManager {
|
||||
|
||||
impl ConnectionManager {
|
||||
pub fn new(connection_string: String) -> Self {
|
||||
Self { connection_string, is_healthy: RwLock::new(true) }
|
||||
Self {
|
||||
connection_string,
|
||||
is_healthy: RwLock::new(true),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self), level = "info")]
|
||||
pub async fn connect(&self) -> Result<ManagedConnection, OnConnectError> {
|
||||
let config = Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?;
|
||||
let config =
|
||||
Config::from_ado_string(&self.connection_string).map_err(OnConnectError::Config)?;
|
||||
|
||||
info!(config = ?config, "Connecting");
|
||||
let tcp = TcpStream::connect(config.get_addr()).await?;
|
||||
tcp.set_nodelay(true)?;
|
||||
|
||||
// To be able to use Tokio's tcp, we're using the `compat_write` from
|
||||
// the `TokioAsyncWriteCompatExt` to get a stream compatible with the
|
||||
// traits from the `futures` crate.
|
||||
let client = Client::connect(config, tcp.compat_write()).await.map_err(OnConnectError::OnConnect)?;
|
||||
let client = Client::connect(config, tcp.compat_write())
|
||||
.await
|
||||
.map_err(OnConnectError::OnConnect)?;
|
||||
info!("Successfully connected");
|
||||
|
||||
Ok(ManagedConnection(client))
|
||||
}
|
||||
@@ -57,6 +67,7 @@ impl ManagedConnection {
|
||||
sql: &str,
|
||||
params: &[&(dyn tiberius::ToSql)],
|
||||
) -> Result<tiberius::ExecuteResult, tiberius::error::Error> {
|
||||
trace!(%sql, "Executing SQL");
|
||||
self.0.execute(sql, params).await
|
||||
}
|
||||
|
||||
@@ -65,6 +76,7 @@ impl ManagedConnection {
|
||||
sql: &str,
|
||||
params: &[&(dyn tiberius::ToSql)],
|
||||
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
|
||||
trace!(%sql, "Querying SQL");
|
||||
self.0.query(sql, params).await
|
||||
}
|
||||
|
||||
@@ -72,18 +84,26 @@ impl ManagedConnection {
|
||||
&'a mut self,
|
||||
sql: &str,
|
||||
) -> Result<tiberius::QueryStream<'a>, tiberius::error::Error> {
|
||||
trace!(%sql, "Simple querying SQL");
|
||||
self.0.simple_query(sql).await
|
||||
}
|
||||
|
||||
|
||||
pub async fn bulk_insert<'a>(
|
||||
&'a mut self,
|
||||
table: &'a str,
|
||||
) -> Result<tiberius::BulkLoadRequest<'a, Stream>, tiberius::error::Error> {
|
||||
trace!(%table, "Starting bulk insert");
|
||||
self.0.bulk_insert(&table).await
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<u8, tiberius::error::Error> {
|
||||
let row = self.0.simple_query("SELECT 1").await?.into_first_result().await?;
|
||||
async fn ping(&mut self) -> Result<i32, tiberius::error::Error> {
|
||||
let row = self
|
||||
.0
|
||||
.simple_query("SELECT 1")
|
||||
.await?
|
||||
.into_first_result()
|
||||
.await?;
|
||||
info!(?row, "Ping response");
|
||||
let value = row[0].get(0).expect("value is present");
|
||||
Ok(value)
|
||||
}
|
||||
@@ -96,7 +116,7 @@ pub enum PoolError {
|
||||
#[error("On Connect error: {0}")]
|
||||
OnConnect(#[source] OnConnectError),
|
||||
#[error("Unexpected ping response: {0}")]
|
||||
Ping(u8)
|
||||
Ping(i32),
|
||||
}
|
||||
|
||||
impl ManageConnection for ConnectionManager {
|
||||
@@ -117,6 +137,9 @@ impl ManageConnection for ConnectionManager {
|
||||
}
|
||||
|
||||
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
|
||||
self.is_healthy.read().expect("poisoned is_healthy lock").clone()
|
||||
self.is_healthy
|
||||
.read()
|
||||
.expect("poisoned is_healthy lock")
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user