1
0
mirror of https://github.com/bitwarden/server synced 2025-12-25 04:33:26 +00:00

Implemented some changes recommended by Claude

This commit is contained in:
Mark Kincaid
2025-11-05 09:40:36 -08:00
parent bb30b549f2
commit c99a6d1a5a
11 changed files with 784 additions and 125 deletions

View File

@@ -20,7 +20,7 @@ public static class MigrationSettingsFactory
.SetBasePath(Directory.GetCurrentDirectory())
.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true)
.AddJsonFile($"appsettings.{Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") ?? "Production"}.json", optional: true, reloadOnChange: true)
.AddUserSecrets("bitwarden-Api")
.AddUserSecrets("bitwarden-Api") // Load user secrets from the API project
.AddEnvironmentVariables();
var configuration = configBuilder.Build();

View File

@@ -290,7 +290,7 @@ public class CsvHandler(CsvSettings settings, ILogger<CsvHandler> logger)
if (string.IsNullOrEmpty(value))
{
processedRow[i] = null!;
processedRow[i] = DBNull.Value;
}
else if (specialColumns.Contains(colName))
{

View File

@@ -5,6 +5,9 @@ using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// MariaDB database importer that handles schema creation, data import, and foreign key management.
/// </summary>
public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> logger) : IDatabaseImporter
{
private readonly ILogger<MariaDbImporter> _logger = logger;
@@ -14,16 +17,24 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
private readonly string _username = config.Username;
private readonly string _password = config.Password;
private MySqlConnection? _connection;
private bool _disposed = false;
/// <summary>
/// Connects to the MariaDB database.
/// </summary>
public bool Connect()
{
try
{
var connectionString = $"Server={_host};Port={_port};Database={_database};" +
$"Uid={_username};Pwd={_password};" +
$"ConnectionTimeout=30;CharSet=utf8mb4;AllowLoadLocalInfile=true;MaxPoolSize=100;";
// Build connection string with redacted password for safe logging
var safeConnectionString = $"Server={_host};Port={_port};Database={_database};" +
$"Uid={_username};Pwd={DbSeederConstants.REDACTED_PASSWORD};" +
$"ConnectionTimeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};" +
$"CharSet=utf8mb4;AllowLoadLocalInfile=false;MaxPoolSize={DbSeederConstants.DEFAULT_MAX_POOL_SIZE};";
_connection = new MySqlConnection(connectionString);
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
_connection = new MySqlConnection(actualConnectionString);
_connection.Open();
_logger.LogInformation("Connected to MariaDB: {Host}:{Port}/{Database}", _host, _port, _database);
@@ -47,6 +58,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Creates a table in MariaDB from the provided schema definition.
/// </summary>
public bool CreateTableFromSchema(
string tableName,
List<string> columns,
@@ -58,11 +72,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var mariaColumns = new List<string>();
foreach (var colName in columns)
{
IdentifierValidator.ValidateOrThrow(colName, "column name");
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
var mariaType = ConvertSqlServerTypeToMariaDB(sqlServerType, specialColumns.Contains(colName));
mariaColumns.Add($"`{colName}` {mariaType}");
@@ -122,15 +140,20 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Imports data into a MariaDB table using batched INSERT statements.
/// </summary>
public bool ImportData(
string tableName,
List<string> columns,
List<object[]> data,
int batchSize = 1000)
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -146,7 +169,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
return false;
}
// Filter columns
var validColumnIndices = new List<int>();
var validColumns = new List<string>();
@@ -154,6 +176,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
{
if (actualColumns.Contains(columns[i]))
{
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
@@ -217,14 +240,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
transaction.Commit();
totalImported += batch.Count;
if (filteredData.Count > 1000)
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
}
}
catch
catch (Exception batchEx)
{
transaction.Rollback();
_logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message);
transaction.SafeRollback(_connection, _logger, tableName);
throw;
}
}
@@ -239,11 +263,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Checks if a table exists in the MariaDB database.
/// </summary>
public bool TableExists(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -255,7 +284,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
command.Parameters.AddWithValue("@database", _database);
command.Parameters.AddWithValue("@tableName", tableName);
var count = Convert.ToInt32(command.ExecuteScalar());
var count = command.GetScalarValue<int>(0, _logger, $"table existence check for {tableName}");
return count > 0;
}
catch (Exception ex)
@@ -265,17 +294,22 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Gets the row count for a specific table.
/// </summary>
public int GetTableRowCount(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"SELECT COUNT(*) FROM `{tableName}`";
using var command = new MySqlCommand(query, _connection);
return Convert.ToInt32(command.ExecuteScalar());
return command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
}
catch (Exception ex)
{
@@ -284,11 +318,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Drops a table from the MariaDB database.
/// </summary>
public bool DropTable(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"DROP TABLE IF EXISTS `{tableName}`";
@@ -448,6 +487,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
return true; // MariaDB multi-row INSERT is optimized
}
/// <summary>
/// Imports data using optimized multi-row INSERT statements for better performance.
/// </summary>
public bool ImportDataBulk(
string tableName,
List<string> columns,
@@ -456,6 +498,8 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -471,7 +515,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
return false;
}
// Filter columns
var validColumnIndices = new List<int>();
var validColumns = new List<string>();
@@ -479,6 +522,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
{
if (actualColumns.Contains(columns[i]))
{
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
@@ -496,10 +540,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
_logger.LogInformation("Bulk importing {Count} rows into {TableName} using multi-row INSERT", filteredData.Count, tableName);
// Use multi-row INSERT for better performance
// INSERT INTO table (col1, col2) VALUES (val1, val2), (val3, val4), ...
// MariaDB can handle up to max_allowed_packet size, we'll use 5000 rows per batch
const int rowsPerBatch = 5000;
const int rowsPerBatch = DbSeederConstants.LARGE_BATCH_SIZE;
var totalImported = 0;
for (int i = 0; i < filteredData.Count; i += rowsPerBatch)
@@ -536,7 +577,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
var fullInsertSql = columnPart + string.Join(", ", valueSets);
using var command = new MySqlCommand(fullInsertSql, _connection, transaction);
command.CommandTimeout = 300; // 5 minutes timeout for large batches
command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT;
// Add all parameters
foreach (var (name, value) in allParameters)
@@ -562,7 +603,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
transaction.Commit();
totalImported += batch.Count;
if (filteredData.Count > 1000)
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
}
@@ -570,19 +611,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
catch (Exception batchEx)
{
_logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message);
try
{
if (_connection?.State == System.Data.ConnectionState.Open)
{
transaction.Rollback();
}
}
catch (Exception rollbackEx)
{
_logger.LogWarning("Could not rollback transaction (connection may be closed): {Message}", rollbackEx.Message);
}
transaction.SafeRollback(_connection, _logger, tableName);
throw;
}
}
@@ -602,6 +631,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Tests the connection to MariaDB by executing a simple query.
/// </summary>
public bool TestConnection()
{
try
@@ -609,9 +641,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
if (Connect())
{
using var command = new MySqlCommand("SELECT 1", _connection);
var result = command.ExecuteScalar();
var result = command.GetScalarValue<int>(0, _logger, "connection test");
Disconnect();
return result != null && Convert.ToInt32(result) == 1;
return result == 1;
}
return false;
}
@@ -622,8 +654,27 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}
}
/// <summary>
/// Disposes of the MariaDB importer and releases all resources.
/// </summary>
public void Dispose()
{
Disconnect();
Dispose(true);
GC.SuppressFinalize(this);
}
/// <summary>
/// Implements Dispose pattern for resource cleanup.
/// </summary>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
Disconnect();
}
_disposed = true;
}
}
}

View File

@@ -1,10 +1,14 @@
using Npgsql;
using NpgsqlTypes;
using Bit.Seeder.Migration.Models;
using Bit.Seeder.Migration.Utils;
using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// PostgreSQL database importer that handles schema creation, data import, and constraint management.
/// </summary>
public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> logger) : IDatabaseImporter
{
private readonly ILogger<PostgresImporter> _logger = logger;
@@ -14,16 +18,22 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
private readonly string _username = config.Username;
private readonly string _password = config.Password;
private NpgsqlConnection? _connection;
private bool _disposed = false;
/// <summary>
/// Connects to the PostgreSQL database.
/// </summary>
public bool Connect()
{
try
{
var connectionString = $"Host={_host};Port={_port};Database={_database};" +
$"Username={_username};Password={_password};" +
$"Timeout=30;CommandTimeout=30;";
var safeConnectionString = $"Host={_host};Port={_port};Database={_database};" +
$"Username={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
$"Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};CommandTimeout={DbSeederConstants.DEFAULT_COMMAND_TIMEOUT};";
_connection = new NpgsqlConnection(connectionString);
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
_connection = new NpgsqlConnection(actualConnectionString);
_connection.Open();
_logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database);
@@ -47,6 +57,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}
}
/// <summary>
/// Creates a table from the provided schema definition.
/// </summary>
public bool CreateTableFromSchema(
string tableName,
List<string> columns,
@@ -58,12 +71,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
// Convert SQL Server types to PostgreSQL types
var pgColumns = new List<string>();
foreach (var colName in columns)
{
IdentifierValidator.ValidateOrThrow(colName, "column name");
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
var pgType = ConvertSqlServerTypeToPostgreSQL(sqlServerType, specialColumns.Contains(colName));
pgColumns.Add($"\"{colName}\" {pgType}");
@@ -185,15 +202,20 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}
}
/// <summary>
/// Imports data into a table using batch INSERT statements.
/// </summary>
public bool ImportData(
string tableName,
List<string> columns,
List<object[]> data,
int batchSize = 1000)
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -210,6 +232,8 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
return false;
}
IdentifierValidator.ValidateOrThrow(actualTableName, "actual table name");
var actualColumns = GetTableColumns(tableName);
if (actualColumns.Count == 0)
{
@@ -319,14 +343,14 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
transaction.Commit();
totalImported += batch.Count;
if (filteredData.Count > 1000)
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
}
}
catch
{
transaction.Rollback();
transaction.SafeRollback(_connection, _logger, tableName);
throw;
}
}
@@ -346,11 +370,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}
}
/// <summary>
/// Checks if a table exists in the database.
/// </summary>
public bool TableExists(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -362,7 +391,7 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
using var command = new NpgsqlCommand(query, _connection);
command.Parameters.AddWithValue("tableName", tableName);
return (bool)command.ExecuteScalar()!;
return command.GetScalarValue<bool>(false, _logger, $"table existence check for {tableName}");
}
catch (Exception ex)
{
@@ -756,6 +785,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}
}
/// <summary>
/// Tests the connection to PostgreSQL by executing a simple query.
/// </summary>
public bool TestConnection()
{
try
@@ -763,9 +795,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
if (Connect())
{
using var command = new NpgsqlCommand("SELECT 1", _connection);
var result = command.ExecuteScalar();
var result = command.GetScalarValue<int>(0, _logger, "connection test");
Disconnect();
return result != null && (int)result == 1;
return result == 1;
}
return false;
}
@@ -776,8 +808,27 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}
}
/// <summary>
/// Disposes of the PostgreSQL importer and releases all resources.
/// </summary>
public void Dispose()
{
Disconnect();
Dispose(true);
GC.SuppressFinalize(this);
}
/// <summary>
/// Implements Dispose pattern for resource cleanup.
/// </summary>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
Disconnect();
}
_disposed = true;
}
}
}

View File

@@ -1,9 +1,13 @@
using Microsoft.Data.SqlClient;
using Bit.Seeder.Migration.Models;
using Bit.Seeder.Migration.Utils;
using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// SQL Server database exporter that handles schema discovery and data export.
/// </summary>
public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter> logger) : IDisposable
{
private readonly ILogger<SqlServerExporter> _logger = logger;
@@ -13,16 +17,23 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
private readonly string _username = config.Username;
private readonly string _password = config.Password;
private SqlConnection? _connection;
private bool _disposed = false;
/// <summary>
/// Connects to the SQL Server database.
/// </summary>
public bool Connect()
{
try
{
var connectionString = $"Server={_host},{_port};Database={_database};" +
$"User Id={_username};Password={_password};" +
$"TrustServerCertificate=True;Connection Timeout=30;";
var safeConnectionString = $"Server={_host},{_port};Database={_database};" +
$"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
$"TrustServerCertificate=True;" +
$"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};";
_connection = new SqlConnection(connectionString);
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
_connection = new SqlConnection(actualConnectionString);
_connection.Open();
_logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
@@ -35,6 +46,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
/// <summary>
/// Disconnects from the SQL Server database.
/// </summary>
public void Disconnect()
{
if (_connection != null)
@@ -46,6 +60,11 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
/// <summary>
/// Discovers all tables in the SQL Server database.
/// </summary>
/// <param name="excludeSystemTables">Whether to exclude system tables</param>
/// <returns>List of table names</returns>
public List<string> DiscoverTables(bool excludeSystemTables = true)
{
if (_connection == null)
@@ -86,11 +105,18 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
/// <summary>
/// Gets detailed information about a table including columns, types, and row count.
/// </summary>
/// <param name="tableName">The name of the table to query</param>
/// <returns>TableInfo containing schema and metadata</returns>
public TableInfo GetTableInfo(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
// Get column information
@@ -146,13 +172,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
throw new InvalidOperationException($"Table '{tableName}' not found");
}
// Get row count
var countQuery = $"SELECT COUNT(*) FROM [{tableName}]";
int rowCount;
using (var command = new SqlCommand(countQuery, _connection))
{
rowCount = (int)command.ExecuteScalar()!;
rowCount = command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
}
_logger.LogInformation("Table {TableName}: {ColumnCount} columns, {RowCount} rows", tableName, columns.Count, rowCount);
@@ -172,16 +197,32 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
public (List<string> Columns, List<object[]> Data) ExportTableData(string tableName, int batchSize = 10000)
/// <summary>
/// Exports all data from a table with streaming to avoid memory exhaustion.
/// </summary>
/// <param name="tableName">The name of the table to export</param>
/// <param name="batchSize">Batch size for progress reporting</param>
/// <returns>Tuple of column names and data rows</returns>
public (List<string> Columns, List<object[]> Data) ExportTableData(
string tableName,
int batchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
// Get table info first
var tableInfo = GetTableInfo(tableName);
// Validate all column names
foreach (var colName in tableInfo.Columns)
{
IdentifierValidator.ValidateOrThrow(colName, "column name");
}
// Build column list with proper quoting
var quotedColumns = tableInfo.Columns.Select(col => $"[{col}]").ToList();
var columnList = string.Join(", ", quotedColumns);
@@ -191,16 +232,24 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
_logger.LogInformation("Executing export query for {TableName}", tableName);
using var command = new SqlCommand(query, _connection);
command.CommandTimeout = 300; // 5 minutes
command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT;
using var reader = command.ExecuteReader();
// Fetch data in batches
// Identify GUID columns for in-place uppercase conversion
var guidColumnIndices = IdentifyGuidColumns(tableInfo);
// Fetch data in batches - still loads into memory but with progress reporting
// Note: For true streaming, consumers should use yield return pattern
var allData = new List<object[]>();
while (reader.Read())
{
var row = new object[tableInfo.Columns.Count];
reader.GetValues(row);
// Convert GUID values in-place to uppercase for Bitwarden compatibility
ConvertGuidsToUppercaseInPlace(row, guidColumnIndices);
allData.Add(row);
if (allData.Count % batchSize == 0)
@@ -209,11 +258,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
// Convert GUID values to uppercase to ensure compatibility with Bitwarden
var processedData = ConvertGuidsToUppercase(allData, tableInfo);
_logger.LogInformation("Exported {Count} rows from {TableName}", processedData.Count, tableName);
return (tableInfo.Columns, processedData);
_logger.LogInformation("Exported {Count} rows from {TableName}", allData.Count, tableName);
return (tableInfo.Columns, allData);
}
catch (Exception ex)
{
@@ -222,11 +268,21 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
public List<string> IdentifyJsonColumns(string tableName, int sampleSize = 100)
/// <summary>
/// Identifies columns that likely contain JSON data by sampling values.
/// </summary>
/// <param name="tableName">The name of the table to analyze</param>
/// <param name="sampleSize">Number of rows to sample for analysis</param>
/// <returns>List of column names that appear to contain JSON</returns>
public List<string> IdentifyJsonColumns(
string tableName,
int sampleSize = DbSeederConstants.DEFAULT_SAMPLE_SIZE)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var tableInfo = GetTableInfo(tableName);
@@ -243,6 +299,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
if (textColumns.Count == 0)
return jsonColumns;
// Validate all column names
foreach (var colName in textColumns)
{
IdentifierValidator.ValidateOrThrow(colName, "column name");
}
// Sample data from text columns
var quotedColumns = textColumns.Select(col => $"[{col}]").ToList();
var columnList = string.Join(", ", quotedColumns);
@@ -266,6 +328,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
// Analyze each column
// JSON detection threshold: 50% of samples must look like JSON
const double jsonThreshold = 0.5;
for (int i = 0; i < textColumns.Count; i++)
{
var colName = textColumns[i];
@@ -288,8 +353,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
// If more than 50% of non-null values look like JSON, mark as JSON column
if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > 0.5)
// If more than threshold of non-null values look like JSON, mark as JSON column
if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > jsonThreshold)
{
jsonColumns.Add(colName);
_logger.LogInformation("Identified {ColumnName} as likely JSON column ({JsonIndicators}/{TotalNonNull} samples)", colName, jsonIndicators, totalNonNull);
@@ -305,13 +370,15 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
private List<object[]> ConvertGuidsToUppercase(List<object[]> data, TableInfo tableInfo)
/// <summary>
/// Identifies GUID column indices in a table for efficient processing.
/// </summary>
/// <param name="tableInfo">Table metadata including column types</param>
/// <returns>List of column indices that are GUID/uniqueidentifier columns</returns>
private List<int> IdentifyGuidColumns(TableInfo tableInfo)
{
if (data.Count == 0 || tableInfo.ColumnTypes.Count == 0)
return data;
// Identify GUID columns (uniqueidentifier type in SQL Server)
var guidColumnIndices = new List<int>();
for (int i = 0; i < tableInfo.Columns.Count; i++)
{
var columnName = tableInfo.Columns[i];
@@ -325,34 +392,36 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
if (guidColumnIndices.Count == 0)
if (guidColumnIndices.Count > 0)
{
_logger.LogDebug("No GUID columns found, returning data unchanged");
return data;
_logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count);
}
_logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count);
// Process each row and convert GUID values to uppercase
var processedData = new List<object[]>();
foreach (var row in data)
{
var rowList = row.ToList();
foreach (var guidIdx in guidColumnIndices)
{
if (guidIdx < rowList.Count && rowList[guidIdx] != null && rowList[guidIdx] != DBNull.Value)
{
var guidValue = rowList[guidIdx].ToString();
// Convert to uppercase, preserving the GUID format
rowList[guidIdx] = guidValue?.ToUpper() ?? string.Empty;
}
}
processedData.Add(rowList.ToArray());
}
return processedData;
return guidColumnIndices;
}
/// <summary>
/// Converts GUID values to uppercase in-place within a data row.
/// More efficient than creating new arrays as it modifies the array directly.
/// </summary>
/// <param name="row">The data row to modify</param>
/// <param name="guidColumnIndices">Indices of columns containing GUIDs</param>
private void ConvertGuidsToUppercaseInPlace(object[] row, List<int> guidColumnIndices)
{
foreach (var guidIdx in guidColumnIndices)
{
if (guidIdx < row.Length && row[guidIdx] != null && row[guidIdx] != DBNull.Value)
{
var guidValue = row[guidIdx].ToString();
// Convert to uppercase in-place, preserving the GUID format
row[guidIdx] = guidValue?.ToUpper() ?? string.Empty;
}
}
}
/// <summary>
/// Tests the connection to SQL Server by executing a simple query.
/// </summary>
public bool TestConnection()
{
try
@@ -360,9 +429,10 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
if (Connect())
{
using var command = new SqlCommand("SELECT 1", _connection);
var result = command.ExecuteScalar();
// Use null-safe scalar value retrieval
var result = command.GetScalarValue<int>(0, _logger, "connection test");
Disconnect();
return result != null && (int)result == 1;
return result == 1;
}
return false;
}
@@ -373,8 +443,27 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
}
}
/// <summary>
/// Disposes of the SQL Server exporter and releases all resources.
/// </summary>
public void Dispose()
{
Disconnect();
Dispose(true);
GC.SuppressFinalize(this);
}
/// <summary>
/// Implements Dispose pattern for resource cleanup.
/// </summary>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
Disconnect();
}
_disposed = true;
}
}
}

View File

@@ -6,6 +6,9 @@ using System.Data;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// SQL Server database importer that handles schema creation, data import, and constraint management.
/// </summary>
public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter> logger) : IDatabaseImporter
{
private readonly ILogger<SqlServerImporter> _logger = logger;
@@ -15,17 +18,24 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
private readonly string _username = config.Username;
private readonly string _password = config.Password;
private SqlConnection? _connection;
private bool _disposed = false;
private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]";
/// <summary>
/// Connects to the SQL Server database.
/// </summary>
public bool Connect()
{
try
{
var connectionString = $"Server={_host},{_port};Database={_database};" +
$"User Id={_username};Password={_password};" +
$"TrustServerCertificate=True;Connection Timeout=30;";
var safeConnectionString = $"Server={_host},{_port};Database={_database};" +
$"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
$"TrustServerCertificate=True;" +
$"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};";
_connection = new SqlConnection(connectionString);
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
_connection = new SqlConnection(actualConnectionString);
_connection.Open();
_logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
@@ -38,6 +48,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Disconnects from the SQL Server database.
/// </summary>
public void Disconnect()
{
if (_connection != null)
@@ -49,11 +62,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Gets the list of columns for a table.
/// </summary>
public List<string> GetTableColumns(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -81,11 +99,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Gets the column types for a table.
/// </summary>
private Dictionary<string, string> GetTableColumnTypes(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -112,11 +135,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Checks if a table exists in the database.
/// </summary>
public bool TableExists(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -127,7 +155,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
using var command = new SqlCommand(query, _connection);
command.Parameters.AddWithValue("@TableName", tableName);
var count = (int)command.ExecuteScalar()!;
var count = command.GetScalarValue<int>(0, _logger, $"table existence check for {tableName}");
return count > 0;
}
catch (Exception ex)
@@ -137,17 +165,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Gets the row count for a specific table.
/// </summary>
public int GetTableRowCount(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"SELECT COUNT(*) FROM [{tableName}]";
using var command = new SqlCommand(query, _connection);
var count = (int)command.ExecuteScalar()!;
var count = command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
_logger.LogDebug("Row count for {TableName}: {Count}", tableName, count);
return count;
}
@@ -158,11 +191,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Drops a table from the database.
/// </summary>
public bool DropTable(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"DROP TABLE IF EXISTS [{tableName}]";
@@ -200,6 +238,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Gets the list of constraints that need to be re-enabled from the tracking table.
/// </summary>
private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable()
{
if (_connection == null)
@@ -212,7 +253,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
// Check if tracking table exists
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
using var checkCommand = new SqlCommand(checkSql, _connection);
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
var count = checkCommand.GetScalarValue<int>(0, _logger, "tracking table existence check");
var tableExists = count > 0;
if (!tableExists)
{
@@ -248,6 +291,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
return constraints;
}
/// <summary>
/// Disables all foreign key constraints and tracks them for re-enabling.
/// </summary>
public bool DisableForeignKeys()
{
if (_connection == null)
@@ -261,7 +307,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
using (var checkCommand = new SqlCommand(checkSql, _connection))
{
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
var count = checkCommand.GetScalarValue<int>(0, _logger, "tracking table existence check");
var tableExists = count > 0;
if (tableExists)
{
@@ -366,6 +413,11 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
try
{
// Validate identifiers to prevent SQL injection
IdentifierValidator.ValidateOrThrow(schema, "schema name");
IdentifierValidator.ValidateOrThrow(table, "table name");
IdentifierValidator.ValidateOrThrow(constraint, "constraint name");
// Disable the constraint
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
using var disableCommand = new SqlCommand(disableSql, _connection, transaction);
@@ -401,7 +453,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
// If anything fails, rollback the transaction
// This ensures the tracking table doesn't exist with incomplete data
transaction.Rollback();
// Safely rollback transaction, preserving original exception
transaction.SafeRollback(_connection, _logger, "foreign key constraint disabling");
throw;
}
}
@@ -412,6 +465,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Re-enables all foreign key constraints that were disabled.
/// </summary>
public bool EnableForeignKeys()
{
if (_connection == null)
@@ -437,6 +493,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
try
{
IdentifierValidator.ValidateOrThrow(schema, "schema name");
IdentifierValidator.ValidateOrThrow(table, "table name");
IdentifierValidator.ValidateOrThrow(constraint, "constraint name");
var enableSql = $"ALTER TABLE [{schema}].[{table}] CHECK CONSTRAINT [{constraint}]";
using var command = new SqlCommand(enableSql, _connection);
command.ExecuteNonQuery();
@@ -464,6 +524,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Creates a table from the provided schema definition.
/// </summary>
public bool CreateTableFromSchema(
string tableName,
List<string> columns,
@@ -475,12 +538,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
// Build column definitions
var sqlServerColumns = new List<string>();
foreach (var colName in columns)
{
IdentifierValidator.ValidateOrThrow(colName, "column name");
var colType = columnTypes.GetValueOrDefault(colName, "NVARCHAR(MAX)");
// If it's a special JSON column, ensure it's a large text type
@@ -516,11 +583,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Gets the list of identity columns for a table.
/// </summary>
public List<string> GetIdentityColumns(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = @"
@@ -548,11 +620,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Enables IDENTITY_INSERT for a table to allow explicit identity values.
/// </summary>
public bool EnableIdentityInsert(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"SET IDENTITY_INSERT [{tableName}] ON";
@@ -569,11 +646,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Disables IDENTITY_INSERT for a table.
/// </summary>
public bool DisableIdentityInsert(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"SET IDENTITY_INSERT [{tableName}] OFF";
@@ -590,15 +672,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Imports data into a table using batch insert statements.
/// </summary>
public bool ImportData(
string tableName,
List<string> columns,
List<object[]> data,
int batchSize = 1000)
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -624,6 +711,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
if (actualColumns.Contains(columns[i]))
{
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
@@ -689,7 +778,6 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
// Validate that data was actually inserted
var actualCount = GetTableRowCount(tableName);
_logger.LogInformation("Post-import validation for {TableName}: imported {TotalImported}, table contains {ActualCount}", tableName, totalImported, actualCount);
@@ -708,15 +796,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Imports data using SqlBulkCopy for high performance.
/// </summary>
private int UseSqlBulkCopy(string tableName, List<string> columns, List<object?[]> data)
{
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
using var bulkCopy = new SqlBulkCopy(_connection!)
{
DestinationTableName = $"[{tableName}]",
BatchSize = 10000,
BulkCopyTimeout = 600 // 10 minutes
BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL,
BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT
};
// Map columns
@@ -747,12 +840,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
catch (Exception ex)
{
_logger.LogWarning("SqlBulkCopy failed: {Message}, falling back to batch insert", ex.Message);
return FastBatchImport(tableName, columns, data, 1000);
return FastBatchImport(tableName, columns, data, DbSeederConstants.DEFAULT_BATCH_SIZE);
}
}
/// <summary>
/// Imports data using fast batch INSERT statements with transactions.
/// </summary>
private int FastBatchImport(string tableName, List<string> columns, List<object?[]> data, int batchSize)
{
IdentifierValidator.ValidateOrThrow(tableName, "table name");
foreach (var column in columns)
{
IdentifierValidator.ValidateOrThrow(column, "column name");
}
var quotedColumns = columns.Select(col => $"[{col}]").ToList();
var placeholders = string.Join(", ", columns.Select((_, i) => $"@p{i}"));
var insertSql = $"INSERT INTO [{tableName}] ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})";
@@ -782,14 +885,15 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
transaction.Commit();
totalImported += batch.Count;
if (data.Count > 1000)
if (data.Count > DbSeederConstants.LOGGING_THRESHOLD)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{DataCount} total, {Percentage:F1}%)", batch.Count, totalImported, data.Count, (totalImported / (double)data.Count * 100));
}
}
catch
{
transaction.Rollback();
// Safely rollback transaction, preserving original exception
transaction.SafeRollback(_connection, _logger, tableName);
throw;
}
}
@@ -938,6 +1042,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
return true; // SQL Server SqlBulkCopy is highly optimized
}
/// <summary>
/// Imports data using SqlBulkCopy for high-performance bulk loading.
/// </summary>
public bool ImportDataBulk(
string tableName,
List<string> columns,
@@ -946,6 +1053,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -971,6 +1080,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
if (actualColumns.Contains(columns[i]))
{
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
@@ -1022,8 +1133,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
using var bulkCopy = new SqlBulkCopy(_connection, bulkCopyOptions, null)
{
DestinationTableName = $"[{tableName}]",
BatchSize = 10000,
BulkCopyTimeout = 600 // 10 minutes
BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL,
BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT
};
// Map columns
@@ -1063,6 +1174,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Tests the connection to SQL Server by executing a simple query.
/// </summary>
public bool TestConnection()
{
try
@@ -1070,9 +1184,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
if (Connect())
{
using var command = new SqlCommand("SELECT 1", _connection);
var result = command.ExecuteScalar();
var result = command.GetScalarValue<int>(0, _logger, "connection test");
Disconnect();
return result != null && (int)result == 1;
return result == 1;
}
return false;
}
@@ -1083,8 +1197,27 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
/// <summary>
/// Disposes of the SQL Server importer and releases all resources.
/// </summary>
public void Dispose()
{
Disconnect();
Dispose(true);
GC.SuppressFinalize(this);
}
/// <summary>
/// Implements Dispose pattern for resource cleanup.
/// </summary>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
Disconnect();
}
_disposed = true;
}
}
}

View File

@@ -5,11 +5,15 @@ using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// SQLite database importer that handles schema creation and data import.
/// </summary>
public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logger) : IDatabaseImporter
{
private readonly ILogger<SqliteImporter> _logger = logger;
private readonly string _databasePath = config.Database;
private SqliteConnection? _connection;
private bool _disposed = false;
public bool Connect()
{
@@ -74,6 +78,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Creates a table from the provided schema definition.
/// </summary>
public bool CreateTableFromSchema(
string tableName,
List<string> columns,
@@ -85,11 +92,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var sqliteColumns = new List<string>();
foreach (var colName in columns)
{
// Validate each column name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(colName, "column name");
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
var sqliteType = ConvertSqlServerTypeToSQLite(sqlServerType, specialColumns.Contains(colName));
sqliteColumns.Add($"\"{colName}\" {sqliteType}");
@@ -116,11 +129,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Gets the list of columns for a table.
/// </summary>
public List<string> GetTableColumns(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"PRAGMA table_info(\"{tableName}\")";
@@ -142,15 +161,21 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Imports data into a table using batch INSERT statements.
/// </summary>
public bool ImportData(
string tableName,
List<string> columns,
List<object[]> data,
int batchSize = 1000)
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
@@ -174,6 +199,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
{
if (actualColumns.Contains(columns[i]))
{
// Validate column name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
@@ -231,7 +259,7 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
totalImported += batch.Count;
if (filteredData.Count > 1000)
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
}
@@ -244,7 +272,8 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
catch
{
transaction.Rollback();
// Safely rollback transaction, preserving original exception
transaction.SafeRollback(_connection, _logger, tableName);
throw;
}
}
@@ -255,18 +284,25 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Checks if a table exists in the database.
/// </summary>
public bool TableExists(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name = @tableName";
using var command = new SqliteCommand(query, _connection);
command.Parameters.AddWithValue("@tableName", tableName);
var count = Convert.ToInt64(command.ExecuteScalar());
// Use null-safe scalar value retrieval
var count = command.GetScalarValue<long>(0, _logger, $"table existence check for {tableName}");
return count > 0;
}
catch (Exception ex)
@@ -276,17 +312,24 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Gets the row count for a specific table.
/// </summary>
public int GetTableRowCount(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"SELECT COUNT(*) FROM \"{tableName}\"";
using var command = new SqliteCommand(query, _connection);
return Convert.ToInt32(command.ExecuteScalar());
// Use null-safe scalar value retrieval
return command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
}
catch (Exception ex)
{
@@ -295,11 +338,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Drops a table from the database.
/// </summary>
public bool DropTable(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
// Validate table name to prevent SQL injection
IdentifierValidator.ValidateOrThrow(tableName, "table name");
try
{
var query = $"DROP TABLE IF EXISTS \"{tableName}\"";
@@ -432,6 +481,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
return false;
}
/// <summary>
/// Tests the connection to SQLite by executing a simple query.
/// </summary>
public bool TestConnection()
{
try
@@ -439,9 +491,10 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
if (Connect())
{
using var command = new SqliteCommand("SELECT 1", _connection);
var result = command.ExecuteScalar();
// Use null-safe scalar value retrieval
var result = command.GetScalarValue<int>(0, _logger, "connection test");
Disconnect();
return result != null && Convert.ToInt32(result) == 1;
return result == 1;
}
return false;
}
@@ -452,8 +505,29 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}
}
/// <summary>
/// Disposes of the SQLite importer and releases all resources.
/// </summary>
public void Dispose()
{
Disconnect();
Dispose(true);
GC.SuppressFinalize(this);
}
/// <summary>
/// Protected implementation of Dispose pattern.
/// </summary>
/// <param name="disposing">True if called from Dispose(), false if called from finalizer</param>
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
// Dispose managed resources
Disconnect();
}
_disposed = true;
}
}
}

View File

@@ -0,0 +1,123 @@
using System.Data;
using System.Data.Common;
using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Utils;
/// <summary>
/// Extension methods for database operations with improved null safety.
/// </summary>
public static class DbExtensions
{
/// <summary>
/// Executes a scalar query and returns the result with proper null handling.
/// Safely handles null and DBNull.Value results without throwing exceptions.
/// </summary>
/// <typeparam name="T">The expected return type</typeparam>
/// <param name="command">The database command to execute</param>
/// <param name="defaultValue">The default value to return if result is null/DBNull</param>
/// <param name="logger">Optional logger for warnings</param>
/// <param name="context">Optional context for logging (e.g., "table count for Users")</param>
/// <returns>The scalar value cast to type T, or defaultValue if null/DBNull</returns>
public static T GetScalarValue<T>(
this DbCommand command,
T defaultValue = default!,
ILogger? logger = null,
string? context = null)
{
var result = command.ExecuteScalar();
if (result == null || result == DBNull.Value)
{
if (logger != null && !string.IsNullOrEmpty(context))
{
logger.LogDebug("Query returned null for {Context}, using default value", context);
}
return defaultValue;
}
try
{
// Handle direct cast if types match
if (result is T typedResult)
{
return typedResult;
}
// Handle conversion for compatible types
return (T)Convert.ChangeType(result, typeof(T));
}
catch (InvalidCastException ex)
{
if (logger != null)
{
logger.LogWarning(
"Could not cast result to {TargetType} for {Context}. Result type: {ActualType}. Error: {Error}",
typeof(T).Name,
context ?? "query",
result.GetType().Name,
ex.Message
);
}
return defaultValue;
}
catch (FormatException ex)
{
if (logger != null)
{
logger.LogWarning(
"Format error converting result to {TargetType} for {Context}. Value: {Value}. Error: {Error}",
typeof(T).Name,
context ?? "query",
result,
ex.Message
);
}
return defaultValue;
}
}
/// <summary>
/// Safely attempts to rollback a transaction, catching and logging any errors.
/// This prevents rollback errors from masking the original exception.
/// </summary>
/// <param name="transaction">The transaction to rollback</param>
/// <param name="connection">The database connection (used to check if still open)</param>
/// <param name="logger">Logger for warnings</param>
/// <param name="context">Context for logging (e.g., table name)</param>
public static void SafeRollback(
this DbTransaction transaction,
DbConnection? connection,
ILogger logger,
string? context = null)
{
try
{
if (connection?.State == ConnectionState.Open)
{
transaction.Rollback();
if (!string.IsNullOrEmpty(context))
{
logger.LogDebug("Transaction rolled back for {Context}", context);
}
}
else
{
logger.LogWarning(
"Could not rollback transaction for {Context}: connection is {State}",
context ?? "operation",
connection?.State.ToString() ?? "null"
);
}
}
catch (Exception rollbackEx)
{
logger.LogWarning(
rollbackEx,
"Error during transaction rollback for {Context}: {Message}",
context ?? "operation",
rollbackEx.Message
);
}
}
}

View File

@@ -0,0 +1,58 @@
namespace Bit.Seeder.Migration.Utils;
/// <summary>
/// Constants used throughout the DbSeeder utility for database operations.
/// </summary>
public static class DbSeederConstants
{
/// <summary>
/// Default sample size for type detection and data analysis.
/// </summary>
public const int DEFAULT_SAMPLE_SIZE = 100;
/// <summary>
/// Default batch size for bulk insert operations.
/// </summary>
public const int DEFAULT_BATCH_SIZE = 1000;
/// <summary>
/// Large batch size for optimized bulk operations (use with caution for max_allowed_packet limits).
/// </summary>
public const int LARGE_BATCH_SIZE = 5000;
/// <summary>
/// Default connection timeout in seconds.
/// </summary>
public const int DEFAULT_CONNECTION_TIMEOUT = 30;
/// <summary>
/// Default command timeout in seconds for regular operations.
/// </summary>
public const int DEFAULT_COMMAND_TIMEOUT = 60;
/// <summary>
/// Extended command timeout in seconds for large batch operations.
/// </summary>
public const int LARGE_BATCH_COMMAND_TIMEOUT = 300; // 5 minutes
/// <summary>
/// Default maximum pool size for database connections.
/// </summary>
public const int DEFAULT_MAX_POOL_SIZE = 100;
/// <summary>
/// Threshold for enabling detailed progress logging (row count).
/// Operations with fewer rows may use simpler logging.
/// </summary>
public const int LOGGING_THRESHOLD = 1000;
/// <summary>
/// Batch size for progress reporting during long-running operations.
/// </summary>
public const int PROGRESS_REPORTING_INTERVAL = 10000;
/// <summary>
/// Placeholder text for redacting passwords in connection strings for safe logging.
/// </summary>
public const string REDACTED_PASSWORD = "***REDACTED***";
}

View File

@@ -0,0 +1,79 @@
using System.Text.RegularExpressions;
namespace Bit.Seeder.Migration.Utils;
/// <summary>
/// Validates SQL identifiers (table names, column names, schema names) to prevent SQL injection.
/// </summary>
public static class IdentifierValidator
{
// Regex pattern for valid SQL identifiers: must start with letter or underscore,
// followed by letters, numbers, or underscores
private static readonly Regex ValidIdentifierPattern = new Regex(
@"^[a-zA-Z_][a-zA-Z0-9_]*$",
RegexOptions.Compiled
);
// Maximum reasonable length for identifiers (most databases have limits around 128-255)
private const int MaxIdentifierLength = 128;
/// <summary>
/// Validates a SQL identifier (table name, column name, schema name).
/// </summary>
/// <param name="identifier">The identifier to validate</param>
/// <returns>True if the identifier is valid, false otherwise</returns>
public static bool IsValid(string? identifier)
{
if (string.IsNullOrWhiteSpace(identifier))
return false;
if (identifier.Length > MaxIdentifierLength)
return false;
return ValidIdentifierPattern.IsMatch(identifier);
}
/// <summary>
/// Validates a SQL identifier and throws an exception if invalid.
/// </summary>
/// <param name="identifier">The identifier to validate</param>
/// <param name="identifierType">The type of identifier (e.g., "table name", "column name")</param>
/// <exception cref="ArgumentException">Thrown when the identifier is invalid</exception>
public static void ValidateOrThrow(string? identifier, string identifierType = "identifier")
{
if (!IsValid(identifier))
{
throw new ArgumentException(
$"Invalid {identifierType}: '{identifier}'. " +
$"Identifiers must start with a letter or underscore, " +
$"contain only letters, numbers, and underscores, " +
$"and be no longer than {MaxIdentifierLength} characters.",
nameof(identifier)
);
}
}
/// <summary>
/// Validates multiple identifiers and throws an exception if any are invalid.
/// </summary>
/// <param name="identifiers">The identifiers to validate</param>
/// <param name="identifierType">The type of identifiers (e.g., "column names")</param>
/// <exception cref="ArgumentException">Thrown when any identifier is invalid</exception>
public static void ValidateAllOrThrow(IEnumerable<string> identifiers, string identifierType = "identifiers")
{
foreach (var identifier in identifiers)
{
ValidateOrThrow(identifier, identifierType);
}
}
/// <summary>
/// Filters a list of identifiers to only include valid ones.
/// </summary>
/// <param name="identifiers">The identifiers to filter</param>
/// <returns>A list of valid identifiers</returns>
public static List<string> FilterValid(IEnumerable<string> identifiers)
{
return identifiers.Where(IsValid).ToList();
}
}

View File

@@ -27,7 +27,8 @@ public static class SecuritySanitizer
foreach (var (key, value) in configDict)
{
if (SensitiveFields.Contains(key.ToLower()))
// Use case-insensitive comparison with proper culture handling
if (SensitiveFields.Any(field => field.Equals(key, StringComparison.OrdinalIgnoreCase)))
{
sanitized[key] = value != null ? MaskPassword(value.ToString() ?? string.Empty) : string.Empty;
}