mirror of
https://github.com/bitwarden/server
synced 2025-12-25 04:33:26 +00:00
Implemented some changes recommended by Claude
This commit is contained in:
@@ -20,7 +20,7 @@ public static class MigrationSettingsFactory
|
||||
.SetBasePath(Directory.GetCurrentDirectory())
|
||||
.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true)
|
||||
.AddJsonFile($"appsettings.{Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") ?? "Production"}.json", optional: true, reloadOnChange: true)
|
||||
.AddUserSecrets("bitwarden-Api")
|
||||
.AddUserSecrets("bitwarden-Api") // Load user secrets from the API project
|
||||
.AddEnvironmentVariables();
|
||||
|
||||
var configuration = configBuilder.Build();
|
||||
|
||||
@@ -290,7 +290,7 @@ public class CsvHandler(CsvSettings settings, ILogger<CsvHandler> logger)
|
||||
|
||||
if (string.IsNullOrEmpty(value))
|
||||
{
|
||||
processedRow[i] = null!;
|
||||
processedRow[i] = DBNull.Value;
|
||||
}
|
||||
else if (specialColumns.Contains(colName))
|
||||
{
|
||||
|
||||
@@ -5,6 +5,9 @@ using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Bit.Seeder.Migration.Databases;
|
||||
|
||||
/// <summary>
|
||||
/// MariaDB database importer that handles schema creation, data import, and foreign key management.
|
||||
/// </summary>
|
||||
public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> logger) : IDatabaseImporter
|
||||
{
|
||||
private readonly ILogger<MariaDbImporter> _logger = logger;
|
||||
@@ -14,16 +17,24 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
private readonly string _username = config.Username;
|
||||
private readonly string _password = config.Password;
|
||||
private MySqlConnection? _connection;
|
||||
private bool _disposed = false;
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the MariaDB database.
|
||||
/// </summary>
|
||||
public bool Connect()
|
||||
{
|
||||
try
|
||||
{
|
||||
var connectionString = $"Server={_host};Port={_port};Database={_database};" +
|
||||
$"Uid={_username};Pwd={_password};" +
|
||||
$"ConnectionTimeout=30;CharSet=utf8mb4;AllowLoadLocalInfile=true;MaxPoolSize=100;";
|
||||
// Build connection string with redacted password for safe logging
|
||||
var safeConnectionString = $"Server={_host};Port={_port};Database={_database};" +
|
||||
$"Uid={_username};Pwd={DbSeederConstants.REDACTED_PASSWORD};" +
|
||||
$"ConnectionTimeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};" +
|
||||
$"CharSet=utf8mb4;AllowLoadLocalInfile=false;MaxPoolSize={DbSeederConstants.DEFAULT_MAX_POOL_SIZE};";
|
||||
|
||||
_connection = new MySqlConnection(connectionString);
|
||||
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
|
||||
|
||||
_connection = new MySqlConnection(actualConnectionString);
|
||||
_connection.Open();
|
||||
|
||||
_logger.LogInformation("Connected to MariaDB: {Host}:{Port}/{Database}", _host, _port, _database);
|
||||
@@ -47,6 +58,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a table in MariaDB from the provided schema definition.
|
||||
/// </summary>
|
||||
public bool CreateTableFromSchema(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -58,11 +72,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var mariaColumns = new List<string>();
|
||||
foreach (var colName in columns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
|
||||
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
|
||||
var mariaType = ConvertSqlServerTypeToMariaDB(sqlServerType, specialColumns.Contains(colName));
|
||||
mariaColumns.Add($"`{colName}` {mariaType}");
|
||||
@@ -122,15 +140,20 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data into a MariaDB table using batched INSERT statements.
|
||||
/// </summary>
|
||||
public bool ImportData(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
List<object[]> data,
|
||||
int batchSize = 1000)
|
||||
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -146,7 +169,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
return false;
|
||||
}
|
||||
|
||||
// Filter columns
|
||||
var validColumnIndices = new List<int>();
|
||||
var validColumns = new List<string>();
|
||||
|
||||
@@ -154,6 +176,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
{
|
||||
if (actualColumns.Contains(columns[i]))
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
|
||||
validColumnIndices.Add(i);
|
||||
validColumns.Add(columns[i]);
|
||||
}
|
||||
@@ -217,14 +240,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
transaction.Commit();
|
||||
totalImported += batch.Count;
|
||||
|
||||
if (filteredData.Count > 1000)
|
||||
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
|
||||
{
|
||||
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
|
||||
}
|
||||
}
|
||||
catch
|
||||
catch (Exception batchEx)
|
||||
{
|
||||
transaction.Rollback();
|
||||
_logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message);
|
||||
transaction.SafeRollback(_connection, _logger, tableName);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -239,11 +263,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a table exists in the MariaDB database.
|
||||
/// </summary>
|
||||
public bool TableExists(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -255,7 +284,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
command.Parameters.AddWithValue("@database", _database);
|
||||
command.Parameters.AddWithValue("@tableName", tableName);
|
||||
|
||||
var count = Convert.ToInt32(command.ExecuteScalar());
|
||||
var count = command.GetScalarValue<int>(0, _logger, $"table existence check for {tableName}");
|
||||
return count > 0;
|
||||
}
|
||||
catch (Exception ex)
|
||||
@@ -265,17 +294,22 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the row count for a specific table.
|
||||
/// </summary>
|
||||
public int GetTableRowCount(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"SELECT COUNT(*) FROM `{tableName}`";
|
||||
using var command = new MySqlCommand(query, _connection);
|
||||
|
||||
return Convert.ToInt32(command.ExecuteScalar());
|
||||
return command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -284,11 +318,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Drops a table from the MariaDB database.
|
||||
/// </summary>
|
||||
public bool DropTable(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"DROP TABLE IF EXISTS `{tableName}`";
|
||||
@@ -448,6 +487,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
return true; // MariaDB multi-row INSERT is optimized
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data using optimized multi-row INSERT statements for better performance.
|
||||
/// </summary>
|
||||
public bool ImportDataBulk(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -456,6 +498,8 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -471,7 +515,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
return false;
|
||||
}
|
||||
|
||||
// Filter columns
|
||||
var validColumnIndices = new List<int>();
|
||||
var validColumns = new List<string>();
|
||||
|
||||
@@ -479,6 +522,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
{
|
||||
if (actualColumns.Contains(columns[i]))
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
|
||||
validColumnIndices.Add(i);
|
||||
validColumns.Add(columns[i]);
|
||||
}
|
||||
@@ -496,10 +540,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
|
||||
_logger.LogInformation("Bulk importing {Count} rows into {TableName} using multi-row INSERT", filteredData.Count, tableName);
|
||||
|
||||
// Use multi-row INSERT for better performance
|
||||
// INSERT INTO table (col1, col2) VALUES (val1, val2), (val3, val4), ...
|
||||
// MariaDB can handle up to max_allowed_packet size, we'll use 5000 rows per batch
|
||||
const int rowsPerBatch = 5000;
|
||||
const int rowsPerBatch = DbSeederConstants.LARGE_BATCH_SIZE;
|
||||
var totalImported = 0;
|
||||
|
||||
for (int i = 0; i < filteredData.Count; i += rowsPerBatch)
|
||||
@@ -536,7 +577,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
var fullInsertSql = columnPart + string.Join(", ", valueSets);
|
||||
|
||||
using var command = new MySqlCommand(fullInsertSql, _connection, transaction);
|
||||
command.CommandTimeout = 300; // 5 minutes timeout for large batches
|
||||
command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT;
|
||||
|
||||
// Add all parameters
|
||||
foreach (var (name, value) in allParameters)
|
||||
@@ -562,7 +603,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
transaction.Commit();
|
||||
totalImported += batch.Count;
|
||||
|
||||
if (filteredData.Count > 1000)
|
||||
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
|
||||
{
|
||||
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
|
||||
}
|
||||
@@ -570,19 +611,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
catch (Exception batchEx)
|
||||
{
|
||||
_logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message);
|
||||
|
||||
try
|
||||
{
|
||||
if (_connection?.State == System.Data.ConnectionState.Open)
|
||||
{
|
||||
transaction.Rollback();
|
||||
}
|
||||
}
|
||||
catch (Exception rollbackEx)
|
||||
{
|
||||
_logger.LogWarning("Could not rollback transaction (connection may be closed): {Message}", rollbackEx.Message);
|
||||
}
|
||||
|
||||
transaction.SafeRollback(_connection, _logger, tableName);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -602,6 +631,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the connection to MariaDB by executing a simple query.
|
||||
/// </summary>
|
||||
public bool TestConnection()
|
||||
{
|
||||
try
|
||||
@@ -609,9 +641,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
if (Connect())
|
||||
{
|
||||
using var command = new MySqlCommand("SELECT 1", _connection);
|
||||
var result = command.ExecuteScalar();
|
||||
var result = command.GetScalarValue<int>(0, _logger, "connection test");
|
||||
Disconnect();
|
||||
return result != null && Convert.ToInt32(result) == 1;
|
||||
return result == 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -622,8 +654,27 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disposes of the MariaDB importer and releases all resources.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Disconnect();
|
||||
Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Implements Dispose pattern for resource cleanup.
|
||||
/// </summary>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
Disconnect();
|
||||
}
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
using Npgsql;
|
||||
using NpgsqlTypes;
|
||||
using Bit.Seeder.Migration.Models;
|
||||
using Bit.Seeder.Migration.Utils;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Bit.Seeder.Migration.Databases;
|
||||
|
||||
/// <summary>
|
||||
/// PostgreSQL database importer that handles schema creation, data import, and constraint management.
|
||||
/// </summary>
|
||||
public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> logger) : IDatabaseImporter
|
||||
{
|
||||
private readonly ILogger<PostgresImporter> _logger = logger;
|
||||
@@ -14,16 +18,22 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
private readonly string _username = config.Username;
|
||||
private readonly string _password = config.Password;
|
||||
private NpgsqlConnection? _connection;
|
||||
private bool _disposed = false;
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the PostgreSQL database.
|
||||
/// </summary>
|
||||
public bool Connect()
|
||||
{
|
||||
try
|
||||
{
|
||||
var connectionString = $"Host={_host};Port={_port};Database={_database};" +
|
||||
$"Username={_username};Password={_password};" +
|
||||
$"Timeout=30;CommandTimeout=30;";
|
||||
var safeConnectionString = $"Host={_host};Port={_port};Database={_database};" +
|
||||
$"Username={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
|
||||
$"Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};CommandTimeout={DbSeederConstants.DEFAULT_COMMAND_TIMEOUT};";
|
||||
|
||||
_connection = new NpgsqlConnection(connectionString);
|
||||
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
|
||||
|
||||
_connection = new NpgsqlConnection(actualConnectionString);
|
||||
_connection.Open();
|
||||
|
||||
_logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database);
|
||||
@@ -47,6 +57,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a table from the provided schema definition.
|
||||
/// </summary>
|
||||
public bool CreateTableFromSchema(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -58,12 +71,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
// Convert SQL Server types to PostgreSQL types
|
||||
var pgColumns = new List<string>();
|
||||
foreach (var colName in columns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
|
||||
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
|
||||
var pgType = ConvertSqlServerTypeToPostgreSQL(sqlServerType, specialColumns.Contains(colName));
|
||||
pgColumns.Add($"\"{colName}\" {pgType}");
|
||||
@@ -185,15 +202,20 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data into a table using batch INSERT statements.
|
||||
/// </summary>
|
||||
public bool ImportData(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
List<object[]> data,
|
||||
int batchSize = 1000)
|
||||
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -210,6 +232,8 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
return false;
|
||||
}
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(actualTableName, "actual table name");
|
||||
|
||||
var actualColumns = GetTableColumns(tableName);
|
||||
if (actualColumns.Count == 0)
|
||||
{
|
||||
@@ -319,14 +343,14 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
transaction.Commit();
|
||||
totalImported += batch.Count;
|
||||
|
||||
if (filteredData.Count > 1000)
|
||||
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
|
||||
{
|
||||
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
transaction.Rollback();
|
||||
transaction.SafeRollback(_connection, _logger, tableName);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -346,11 +370,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a table exists in the database.
|
||||
/// </summary>
|
||||
public bool TableExists(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -362,7 +391,7 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
using var command = new NpgsqlCommand(query, _connection);
|
||||
command.Parameters.AddWithValue("tableName", tableName);
|
||||
|
||||
return (bool)command.ExecuteScalar()!;
|
||||
return command.GetScalarValue<bool>(false, _logger, $"table existence check for {tableName}");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -756,6 +785,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the connection to PostgreSQL by executing a simple query.
|
||||
/// </summary>
|
||||
public bool TestConnection()
|
||||
{
|
||||
try
|
||||
@@ -763,9 +795,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
if (Connect())
|
||||
{
|
||||
using var command = new NpgsqlCommand("SELECT 1", _connection);
|
||||
var result = command.ExecuteScalar();
|
||||
var result = command.GetScalarValue<int>(0, _logger, "connection test");
|
||||
Disconnect();
|
||||
return result != null && (int)result == 1;
|
||||
return result == 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -776,8 +808,27 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disposes of the PostgreSQL importer and releases all resources.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Disconnect();
|
||||
Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Implements Dispose pattern for resource cleanup.
|
||||
/// </summary>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
Disconnect();
|
||||
}
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
using Microsoft.Data.SqlClient;
|
||||
using Bit.Seeder.Migration.Models;
|
||||
using Bit.Seeder.Migration.Utils;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Bit.Seeder.Migration.Databases;
|
||||
|
||||
/// <summary>
|
||||
/// SQL Server database exporter that handles schema discovery and data export.
|
||||
/// </summary>
|
||||
public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter> logger) : IDisposable
|
||||
{
|
||||
private readonly ILogger<SqlServerExporter> _logger = logger;
|
||||
@@ -13,16 +17,23 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
private readonly string _username = config.Username;
|
||||
private readonly string _password = config.Password;
|
||||
private SqlConnection? _connection;
|
||||
private bool _disposed = false;
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the SQL Server database.
|
||||
/// </summary>
|
||||
public bool Connect()
|
||||
{
|
||||
try
|
||||
{
|
||||
var connectionString = $"Server={_host},{_port};Database={_database};" +
|
||||
$"User Id={_username};Password={_password};" +
|
||||
$"TrustServerCertificate=True;Connection Timeout=30;";
|
||||
var safeConnectionString = $"Server={_host},{_port};Database={_database};" +
|
||||
$"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
|
||||
$"TrustServerCertificate=True;" +
|
||||
$"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};";
|
||||
|
||||
_connection = new SqlConnection(connectionString);
|
||||
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
|
||||
|
||||
_connection = new SqlConnection(actualConnectionString);
|
||||
_connection.Open();
|
||||
|
||||
_logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
|
||||
@@ -35,6 +46,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the SQL Server database.
|
||||
/// </summary>
|
||||
public void Disconnect()
|
||||
{
|
||||
if (_connection != null)
|
||||
@@ -46,6 +60,11 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Discovers all tables in the SQL Server database.
|
||||
/// </summary>
|
||||
/// <param name="excludeSystemTables">Whether to exclude system tables</param>
|
||||
/// <returns>List of table names</returns>
|
||||
public List<string> DiscoverTables(bool excludeSystemTables = true)
|
||||
{
|
||||
if (_connection == null)
|
||||
@@ -86,11 +105,18 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets detailed information about a table including columns, types, and row count.
|
||||
/// </summary>
|
||||
/// <param name="tableName">The name of the table to query</param>
|
||||
/// <returns>TableInfo containing schema and metadata</returns>
|
||||
public TableInfo GetTableInfo(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
// Get column information
|
||||
@@ -146,13 +172,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
throw new InvalidOperationException($"Table '{tableName}' not found");
|
||||
}
|
||||
|
||||
// Get row count
|
||||
var countQuery = $"SELECT COUNT(*) FROM [{tableName}]";
|
||||
int rowCount;
|
||||
|
||||
using (var command = new SqlCommand(countQuery, _connection))
|
||||
{
|
||||
rowCount = (int)command.ExecuteScalar()!;
|
||||
rowCount = command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
|
||||
}
|
||||
|
||||
_logger.LogInformation("Table {TableName}: {ColumnCount} columns, {RowCount} rows", tableName, columns.Count, rowCount);
|
||||
@@ -172,16 +197,32 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
public (List<string> Columns, List<object[]> Data) ExportTableData(string tableName, int batchSize = 10000)
|
||||
/// <summary>
|
||||
/// Exports all data from a table with streaming to avoid memory exhaustion.
|
||||
/// </summary>
|
||||
/// <param name="tableName">The name of the table to export</param>
|
||||
/// <param name="batchSize">Batch size for progress reporting</param>
|
||||
/// <returns>Tuple of column names and data rows</returns>
|
||||
public (List<string> Columns, List<object[]> Data) ExportTableData(
|
||||
string tableName,
|
||||
int batchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
// Get table info first
|
||||
var tableInfo = GetTableInfo(tableName);
|
||||
|
||||
// Validate all column names
|
||||
foreach (var colName in tableInfo.Columns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
}
|
||||
|
||||
// Build column list with proper quoting
|
||||
var quotedColumns = tableInfo.Columns.Select(col => $"[{col}]").ToList();
|
||||
var columnList = string.Join(", ", quotedColumns);
|
||||
@@ -191,16 +232,24 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
_logger.LogInformation("Executing export query for {TableName}", tableName);
|
||||
|
||||
using var command = new SqlCommand(query, _connection);
|
||||
command.CommandTimeout = 300; // 5 minutes
|
||||
command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT;
|
||||
|
||||
using var reader = command.ExecuteReader();
|
||||
|
||||
// Fetch data in batches
|
||||
// Identify GUID columns for in-place uppercase conversion
|
||||
var guidColumnIndices = IdentifyGuidColumns(tableInfo);
|
||||
|
||||
// Fetch data in batches - still loads into memory but with progress reporting
|
||||
// Note: For true streaming, consumers should use yield return pattern
|
||||
var allData = new List<object[]>();
|
||||
while (reader.Read())
|
||||
{
|
||||
var row = new object[tableInfo.Columns.Count];
|
||||
reader.GetValues(row);
|
||||
|
||||
// Convert GUID values in-place to uppercase for Bitwarden compatibility
|
||||
ConvertGuidsToUppercaseInPlace(row, guidColumnIndices);
|
||||
|
||||
allData.Add(row);
|
||||
|
||||
if (allData.Count % batchSize == 0)
|
||||
@@ -209,11 +258,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GUID values to uppercase to ensure compatibility with Bitwarden
|
||||
var processedData = ConvertGuidsToUppercase(allData, tableInfo);
|
||||
|
||||
_logger.LogInformation("Exported {Count} rows from {TableName}", processedData.Count, tableName);
|
||||
return (tableInfo.Columns, processedData);
|
||||
_logger.LogInformation("Exported {Count} rows from {TableName}", allData.Count, tableName);
|
||||
return (tableInfo.Columns, allData);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -222,11 +268,21 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
public List<string> IdentifyJsonColumns(string tableName, int sampleSize = 100)
|
||||
/// <summary>
|
||||
/// Identifies columns that likely contain JSON data by sampling values.
|
||||
/// </summary>
|
||||
/// <param name="tableName">The name of the table to analyze</param>
|
||||
/// <param name="sampleSize">Number of rows to sample for analysis</param>
|
||||
/// <returns>List of column names that appear to contain JSON</returns>
|
||||
public List<string> IdentifyJsonColumns(
|
||||
string tableName,
|
||||
int sampleSize = DbSeederConstants.DEFAULT_SAMPLE_SIZE)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var tableInfo = GetTableInfo(tableName);
|
||||
@@ -243,6 +299,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
if (textColumns.Count == 0)
|
||||
return jsonColumns;
|
||||
|
||||
// Validate all column names
|
||||
foreach (var colName in textColumns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
}
|
||||
|
||||
// Sample data from text columns
|
||||
var quotedColumns = textColumns.Select(col => $"[{col}]").ToList();
|
||||
var columnList = string.Join(", ", quotedColumns);
|
||||
@@ -266,6 +328,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
|
||||
// Analyze each column
|
||||
// JSON detection threshold: 50% of samples must look like JSON
|
||||
const double jsonThreshold = 0.5;
|
||||
|
||||
for (int i = 0; i < textColumns.Count; i++)
|
||||
{
|
||||
var colName = textColumns[i];
|
||||
@@ -288,8 +353,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
// If more than 50% of non-null values look like JSON, mark as JSON column
|
||||
if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > 0.5)
|
||||
// If more than threshold of non-null values look like JSON, mark as JSON column
|
||||
if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > jsonThreshold)
|
||||
{
|
||||
jsonColumns.Add(colName);
|
||||
_logger.LogInformation("Identified {ColumnName} as likely JSON column ({JsonIndicators}/{TotalNonNull} samples)", colName, jsonIndicators, totalNonNull);
|
||||
@@ -305,13 +370,15 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
private List<object[]> ConvertGuidsToUppercase(List<object[]> data, TableInfo tableInfo)
|
||||
/// <summary>
|
||||
/// Identifies GUID column indices in a table for efficient processing.
|
||||
/// </summary>
|
||||
/// <param name="tableInfo">Table metadata including column types</param>
|
||||
/// <returns>List of column indices that are GUID/uniqueidentifier columns</returns>
|
||||
private List<int> IdentifyGuidColumns(TableInfo tableInfo)
|
||||
{
|
||||
if (data.Count == 0 || tableInfo.ColumnTypes.Count == 0)
|
||||
return data;
|
||||
|
||||
// Identify GUID columns (uniqueidentifier type in SQL Server)
|
||||
var guidColumnIndices = new List<int>();
|
||||
|
||||
for (int i = 0; i < tableInfo.Columns.Count; i++)
|
||||
{
|
||||
var columnName = tableInfo.Columns[i];
|
||||
@@ -325,34 +392,36 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
if (guidColumnIndices.Count == 0)
|
||||
if (guidColumnIndices.Count > 0)
|
||||
{
|
||||
_logger.LogDebug("No GUID columns found, returning data unchanged");
|
||||
return data;
|
||||
_logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count);
|
||||
}
|
||||
|
||||
_logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count);
|
||||
|
||||
// Process each row and convert GUID values to uppercase
|
||||
var processedData = new List<object[]>();
|
||||
foreach (var row in data)
|
||||
{
|
||||
var rowList = row.ToList();
|
||||
foreach (var guidIdx in guidColumnIndices)
|
||||
{
|
||||
if (guidIdx < rowList.Count && rowList[guidIdx] != null && rowList[guidIdx] != DBNull.Value)
|
||||
{
|
||||
var guidValue = rowList[guidIdx].ToString();
|
||||
// Convert to uppercase, preserving the GUID format
|
||||
rowList[guidIdx] = guidValue?.ToUpper() ?? string.Empty;
|
||||
}
|
||||
}
|
||||
processedData.Add(rowList.ToArray());
|
||||
}
|
||||
|
||||
return processedData;
|
||||
return guidColumnIndices;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts GUID values to uppercase in-place within a data row.
|
||||
/// More efficient than creating new arrays as it modifies the array directly.
|
||||
/// </summary>
|
||||
/// <param name="row">The data row to modify</param>
|
||||
/// <param name="guidColumnIndices">Indices of columns containing GUIDs</param>
|
||||
private void ConvertGuidsToUppercaseInPlace(object[] row, List<int> guidColumnIndices)
|
||||
{
|
||||
foreach (var guidIdx in guidColumnIndices)
|
||||
{
|
||||
if (guidIdx < row.Length && row[guidIdx] != null && row[guidIdx] != DBNull.Value)
|
||||
{
|
||||
var guidValue = row[guidIdx].ToString();
|
||||
// Convert to uppercase in-place, preserving the GUID format
|
||||
row[guidIdx] = guidValue?.ToUpper() ?? string.Empty;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the connection to SQL Server by executing a simple query.
|
||||
/// </summary>
|
||||
public bool TestConnection()
|
||||
{
|
||||
try
|
||||
@@ -360,9 +429,10 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
if (Connect())
|
||||
{
|
||||
using var command = new SqlCommand("SELECT 1", _connection);
|
||||
var result = command.ExecuteScalar();
|
||||
// Use null-safe scalar value retrieval
|
||||
var result = command.GetScalarValue<int>(0, _logger, "connection test");
|
||||
Disconnect();
|
||||
return result != null && (int)result == 1;
|
||||
return result == 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -373,8 +443,27 @@ public class SqlServerExporter(DatabaseConfig config, ILogger<SqlServerExporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disposes of the SQL Server exporter and releases all resources.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Disconnect();
|
||||
Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Implements Dispose pattern for resource cleanup.
|
||||
/// </summary>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
Disconnect();
|
||||
}
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ using System.Data;
|
||||
|
||||
namespace Bit.Seeder.Migration.Databases;
|
||||
|
||||
/// <summary>
|
||||
/// SQL Server database importer that handles schema creation, data import, and constraint management.
|
||||
/// </summary>
|
||||
public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter> logger) : IDatabaseImporter
|
||||
{
|
||||
private readonly ILogger<SqlServerImporter> _logger = logger;
|
||||
@@ -15,17 +18,24 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
private readonly string _username = config.Username;
|
||||
private readonly string _password = config.Password;
|
||||
private SqlConnection? _connection;
|
||||
private bool _disposed = false;
|
||||
private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]";
|
||||
|
||||
/// <summary>
|
||||
/// Connects to the SQL Server database.
|
||||
/// </summary>
|
||||
public bool Connect()
|
||||
{
|
||||
try
|
||||
{
|
||||
var connectionString = $"Server={_host},{_port};Database={_database};" +
|
||||
$"User Id={_username};Password={_password};" +
|
||||
$"TrustServerCertificate=True;Connection Timeout=30;";
|
||||
var safeConnectionString = $"Server={_host},{_port};Database={_database};" +
|
||||
$"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" +
|
||||
$"TrustServerCertificate=True;" +
|
||||
$"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};";
|
||||
|
||||
_connection = new SqlConnection(connectionString);
|
||||
var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password);
|
||||
|
||||
_connection = new SqlConnection(actualConnectionString);
|
||||
_connection.Open();
|
||||
|
||||
_logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
|
||||
@@ -38,6 +48,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disconnects from the SQL Server database.
|
||||
/// </summary>
|
||||
public void Disconnect()
|
||||
{
|
||||
if (_connection != null)
|
||||
@@ -49,11 +62,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the list of columns for a table.
|
||||
/// </summary>
|
||||
public List<string> GetTableColumns(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -81,11 +99,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the column types for a table.
|
||||
/// </summary>
|
||||
private Dictionary<string, string> GetTableColumnTypes(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -112,11 +135,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a table exists in the database.
|
||||
/// </summary>
|
||||
public bool TableExists(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -127,7 +155,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
using var command = new SqlCommand(query, _connection);
|
||||
command.Parameters.AddWithValue("@TableName", tableName);
|
||||
|
||||
var count = (int)command.ExecuteScalar()!;
|
||||
var count = command.GetScalarValue<int>(0, _logger, $"table existence check for {tableName}");
|
||||
return count > 0;
|
||||
}
|
||||
catch (Exception ex)
|
||||
@@ -137,17 +165,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the row count for a specific table.
|
||||
/// </summary>
|
||||
public int GetTableRowCount(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"SELECT COUNT(*) FROM [{tableName}]";
|
||||
using var command = new SqlCommand(query, _connection);
|
||||
|
||||
var count = (int)command.ExecuteScalar()!;
|
||||
var count = command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
|
||||
_logger.LogDebug("Row count for {TableName}: {Count}", tableName, count);
|
||||
return count;
|
||||
}
|
||||
@@ -158,11 +191,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Drops a table from the database.
|
||||
/// </summary>
|
||||
public bool DropTable(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"DROP TABLE IF EXISTS [{tableName}]";
|
||||
@@ -200,6 +238,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the list of constraints that need to be re-enabled from the tracking table.
|
||||
/// </summary>
|
||||
private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable()
|
||||
{
|
||||
if (_connection == null)
|
||||
@@ -212,7 +253,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
// Check if tracking table exists
|
||||
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
|
||||
using var checkCommand = new SqlCommand(checkSql, _connection);
|
||||
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
|
||||
|
||||
var count = checkCommand.GetScalarValue<int>(0, _logger, "tracking table existence check");
|
||||
var tableExists = count > 0;
|
||||
|
||||
if (!tableExists)
|
||||
{
|
||||
@@ -248,6 +291,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
return constraints;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disables all foreign key constraints and tracks them for re-enabling.
|
||||
/// </summary>
|
||||
public bool DisableForeignKeys()
|
||||
{
|
||||
if (_connection == null)
|
||||
@@ -261,7 +307,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
|
||||
using (var checkCommand = new SqlCommand(checkSql, _connection))
|
||||
{
|
||||
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
|
||||
var count = checkCommand.GetScalarValue<int>(0, _logger, "tracking table existence check");
|
||||
var tableExists = count > 0;
|
||||
|
||||
if (tableExists)
|
||||
{
|
||||
@@ -366,6 +413,11 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
try
|
||||
{
|
||||
// Validate identifiers to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(schema, "schema name");
|
||||
IdentifierValidator.ValidateOrThrow(table, "table name");
|
||||
IdentifierValidator.ValidateOrThrow(constraint, "constraint name");
|
||||
|
||||
// Disable the constraint
|
||||
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
|
||||
using var disableCommand = new SqlCommand(disableSql, _connection, transaction);
|
||||
@@ -401,7 +453,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
// If anything fails, rollback the transaction
|
||||
// This ensures the tracking table doesn't exist with incomplete data
|
||||
transaction.Rollback();
|
||||
// Safely rollback transaction, preserving original exception
|
||||
transaction.SafeRollback(_connection, _logger, "foreign key constraint disabling");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -412,6 +465,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Re-enables all foreign key constraints that were disabled.
|
||||
/// </summary>
|
||||
public bool EnableForeignKeys()
|
||||
{
|
||||
if (_connection == null)
|
||||
@@ -437,6 +493,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
try
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(schema, "schema name");
|
||||
IdentifierValidator.ValidateOrThrow(table, "table name");
|
||||
IdentifierValidator.ValidateOrThrow(constraint, "constraint name");
|
||||
|
||||
var enableSql = $"ALTER TABLE [{schema}].[{table}] CHECK CONSTRAINT [{constraint}]";
|
||||
using var command = new SqlCommand(enableSql, _connection);
|
||||
command.ExecuteNonQuery();
|
||||
@@ -464,6 +524,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a table from the provided schema definition.
|
||||
/// </summary>
|
||||
public bool CreateTableFromSchema(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -475,12 +538,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
// Build column definitions
|
||||
var sqlServerColumns = new List<string>();
|
||||
foreach (var colName in columns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
|
||||
var colType = columnTypes.GetValueOrDefault(colName, "NVARCHAR(MAX)");
|
||||
|
||||
// If it's a special JSON column, ensure it's a large text type
|
||||
@@ -516,11 +583,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the list of identity columns for a table.
|
||||
/// </summary>
|
||||
public List<string> GetIdentityColumns(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = @"
|
||||
@@ -548,11 +620,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Enables IDENTITY_INSERT for a table to allow explicit identity values.
|
||||
/// </summary>
|
||||
public bool EnableIdentityInsert(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"SET IDENTITY_INSERT [{tableName}] ON";
|
||||
@@ -569,11 +646,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disables IDENTITY_INSERT for a table.
|
||||
/// </summary>
|
||||
public bool DisableIdentityInsert(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"SET IDENTITY_INSERT [{tableName}] OFF";
|
||||
@@ -590,15 +672,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data into a table using batch insert statements.
|
||||
/// </summary>
|
||||
public bool ImportData(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
List<object[]> data,
|
||||
int batchSize = 1000)
|
||||
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -624,6 +711,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
if (actualColumns.Contains(columns[i]))
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
|
||||
|
||||
validColumnIndices.Add(i);
|
||||
validColumns.Add(columns[i]);
|
||||
}
|
||||
@@ -689,7 +778,6 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that data was actually inserted
|
||||
var actualCount = GetTableRowCount(tableName);
|
||||
_logger.LogInformation("Post-import validation for {TableName}: imported {TotalImported}, table contains {ActualCount}", tableName, totalImported, actualCount);
|
||||
|
||||
@@ -708,15 +796,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data using SqlBulkCopy for high performance.
|
||||
/// </summary>
|
||||
private int UseSqlBulkCopy(string tableName, List<string> columns, List<object?[]> data)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
using var bulkCopy = new SqlBulkCopy(_connection!)
|
||||
{
|
||||
DestinationTableName = $"[{tableName}]",
|
||||
BatchSize = 10000,
|
||||
BulkCopyTimeout = 600 // 10 minutes
|
||||
BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL,
|
||||
BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT
|
||||
};
|
||||
|
||||
// Map columns
|
||||
@@ -747,12 +840,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning("SqlBulkCopy failed: {Message}, falling back to batch insert", ex.Message);
|
||||
return FastBatchImport(tableName, columns, data, 1000);
|
||||
return FastBatchImport(tableName, columns, data, DbSeederConstants.DEFAULT_BATCH_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data using fast batch INSERT statements with transactions.
|
||||
/// </summary>
|
||||
private int FastBatchImport(string tableName, List<string> columns, List<object?[]> data, int batchSize)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
foreach (var column in columns)
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(column, "column name");
|
||||
}
|
||||
|
||||
var quotedColumns = columns.Select(col => $"[{col}]").ToList();
|
||||
var placeholders = string.Join(", ", columns.Select((_, i) => $"@p{i}"));
|
||||
var insertSql = $"INSERT INTO [{tableName}] ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})";
|
||||
@@ -782,14 +885,15 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
transaction.Commit();
|
||||
totalImported += batch.Count;
|
||||
|
||||
if (data.Count > 1000)
|
||||
if (data.Count > DbSeederConstants.LOGGING_THRESHOLD)
|
||||
{
|
||||
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{DataCount} total, {Percentage:F1}%)", batch.Count, totalImported, data.Count, (totalImported / (double)data.Count * 100));
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
transaction.Rollback();
|
||||
// Safely rollback transaction, preserving original exception
|
||||
transaction.SafeRollback(_connection, _logger, tableName);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -938,6 +1042,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
return true; // SQL Server SqlBulkCopy is highly optimized
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data using SqlBulkCopy for high-performance bulk loading.
|
||||
/// </summary>
|
||||
public bool ImportDataBulk(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -946,6 +1053,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -971,6 +1080,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
if (actualColumns.Contains(columns[i]))
|
||||
{
|
||||
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
|
||||
|
||||
validColumnIndices.Add(i);
|
||||
validColumns.Add(columns[i]);
|
||||
}
|
||||
@@ -1022,8 +1133,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
using var bulkCopy = new SqlBulkCopy(_connection, bulkCopyOptions, null)
|
||||
{
|
||||
DestinationTableName = $"[{tableName}]",
|
||||
BatchSize = 10000,
|
||||
BulkCopyTimeout = 600 // 10 minutes
|
||||
BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL,
|
||||
BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT
|
||||
};
|
||||
|
||||
// Map columns
|
||||
@@ -1063,6 +1174,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the connection to SQL Server by executing a simple query.
|
||||
/// </summary>
|
||||
public bool TestConnection()
|
||||
{
|
||||
try
|
||||
@@ -1070,9 +1184,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
if (Connect())
|
||||
{
|
||||
using var command = new SqlCommand("SELECT 1", _connection);
|
||||
var result = command.ExecuteScalar();
|
||||
var result = command.GetScalarValue<int>(0, _logger, "connection test");
|
||||
Disconnect();
|
||||
return result != null && (int)result == 1;
|
||||
return result == 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -1083,8 +1197,27 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disposes of the SQL Server importer and releases all resources.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Disconnect();
|
||||
Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Implements Dispose pattern for resource cleanup.
|
||||
/// </summary>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
Disconnect();
|
||||
}
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,15 @@ using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Bit.Seeder.Migration.Databases;
|
||||
|
||||
/// <summary>
|
||||
/// SQLite database importer that handles schema creation and data import.
|
||||
/// </summary>
|
||||
public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logger) : IDatabaseImporter
|
||||
{
|
||||
private readonly ILogger<SqliteImporter> _logger = logger;
|
||||
private readonly string _databasePath = config.Database;
|
||||
private SqliteConnection? _connection;
|
||||
private bool _disposed = false;
|
||||
|
||||
public bool Connect()
|
||||
{
|
||||
@@ -74,6 +78,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a table from the provided schema definition.
|
||||
/// </summary>
|
||||
public bool CreateTableFromSchema(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
@@ -85,11 +92,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var sqliteColumns = new List<string>();
|
||||
foreach (var colName in columns)
|
||||
{
|
||||
// Validate each column name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(colName, "column name");
|
||||
|
||||
var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
|
||||
var sqliteType = ConvertSqlServerTypeToSQLite(sqlServerType, specialColumns.Contains(colName));
|
||||
sqliteColumns.Add($"\"{colName}\" {sqliteType}");
|
||||
@@ -116,11 +129,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the list of columns for a table.
|
||||
/// </summary>
|
||||
public List<string> GetTableColumns(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"PRAGMA table_info(\"{tableName}\")";
|
||||
@@ -142,15 +161,21 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Imports data into a table using batch INSERT statements.
|
||||
/// </summary>
|
||||
public bool ImportData(
|
||||
string tableName,
|
||||
List<string> columns,
|
||||
List<object[]> data,
|
||||
int batchSize = 1000)
|
||||
int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
if (data.Count == 0)
|
||||
{
|
||||
_logger.LogWarning("No data to import for table {TableName}", tableName);
|
||||
@@ -174,6 +199,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
{
|
||||
if (actualColumns.Contains(columns[i]))
|
||||
{
|
||||
// Validate column name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(columns[i], "column name");
|
||||
|
||||
validColumnIndices.Add(i);
|
||||
validColumns.Add(columns[i]);
|
||||
}
|
||||
@@ -231,7 +259,7 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
|
||||
totalImported += batch.Count;
|
||||
|
||||
if (filteredData.Count > 1000)
|
||||
if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
|
||||
{
|
||||
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
|
||||
}
|
||||
@@ -244,7 +272,8 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
catch
|
||||
{
|
||||
transaction.Rollback();
|
||||
// Safely rollback transaction, preserving original exception
|
||||
transaction.SafeRollback(_connection, _logger, tableName);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -255,18 +284,25 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if a table exists in the database.
|
||||
/// </summary>
|
||||
public bool TableExists(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name = @tableName";
|
||||
using var command = new SqliteCommand(query, _connection);
|
||||
command.Parameters.AddWithValue("@tableName", tableName);
|
||||
|
||||
var count = Convert.ToInt64(command.ExecuteScalar());
|
||||
// Use null-safe scalar value retrieval
|
||||
var count = command.GetScalarValue<long>(0, _logger, $"table existence check for {tableName}");
|
||||
return count > 0;
|
||||
}
|
||||
catch (Exception ex)
|
||||
@@ -276,17 +312,24 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the row count for a specific table.
|
||||
/// </summary>
|
||||
public int GetTableRowCount(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"SELECT COUNT(*) FROM \"{tableName}\"";
|
||||
using var command = new SqliteCommand(query, _connection);
|
||||
|
||||
return Convert.ToInt32(command.ExecuteScalar());
|
||||
// Use null-safe scalar value retrieval
|
||||
return command.GetScalarValue<int>(0, _logger, $"row count for {tableName}");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -295,11 +338,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Drops a table from the database.
|
||||
/// </summary>
|
||||
public bool DropTable(string tableName)
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
// Validate table name to prevent SQL injection
|
||||
IdentifierValidator.ValidateOrThrow(tableName, "table name");
|
||||
|
||||
try
|
||||
{
|
||||
var query = $"DROP TABLE IF EXISTS \"{tableName}\"";
|
||||
@@ -432,6 +481,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the connection to SQLite by executing a simple query.
|
||||
/// </summary>
|
||||
public bool TestConnection()
|
||||
{
|
||||
try
|
||||
@@ -439,9 +491,10 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
if (Connect())
|
||||
{
|
||||
using var command = new SqliteCommand("SELECT 1", _connection);
|
||||
var result = command.ExecuteScalar();
|
||||
// Use null-safe scalar value retrieval
|
||||
var result = command.GetScalarValue<int>(0, _logger, "connection test");
|
||||
Disconnect();
|
||||
return result != null && Convert.ToInt32(result) == 1;
|
||||
return result == 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -452,8 +505,29 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Disposes of the SQLite importer and releases all resources.
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Disconnect();
|
||||
Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Protected implementation of Dispose pattern.
|
||||
/// </summary>
|
||||
/// <param name="disposing">True if called from Dispose(), false if called from finalizer</param>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
// Dispose managed resources
|
||||
Disconnect();
|
||||
}
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
123
util/Seeder/Migration/Utils/DbExtensions.cs
Normal file
123
util/Seeder/Migration/Utils/DbExtensions.cs
Normal file
@@ -0,0 +1,123 @@
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Bit.Seeder.Migration.Utils;
|
||||
|
||||
/// <summary>
|
||||
/// Extension methods for database operations with improved null safety.
|
||||
/// </summary>
|
||||
public static class DbExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Executes a scalar query and returns the result with proper null handling.
|
||||
/// Safely handles null and DBNull.Value results without throwing exceptions.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">The expected return type</typeparam>
|
||||
/// <param name="command">The database command to execute</param>
|
||||
/// <param name="defaultValue">The default value to return if result is null/DBNull</param>
|
||||
/// <param name="logger">Optional logger for warnings</param>
|
||||
/// <param name="context">Optional context for logging (e.g., "table count for Users")</param>
|
||||
/// <returns>The scalar value cast to type T, or defaultValue if null/DBNull</returns>
|
||||
public static T GetScalarValue<T>(
|
||||
this DbCommand command,
|
||||
T defaultValue = default!,
|
||||
ILogger? logger = null,
|
||||
string? context = null)
|
||||
{
|
||||
var result = command.ExecuteScalar();
|
||||
|
||||
if (result == null || result == DBNull.Value)
|
||||
{
|
||||
if (logger != null && !string.IsNullOrEmpty(context))
|
||||
{
|
||||
logger.LogDebug("Query returned null for {Context}, using default value", context);
|
||||
}
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// Handle direct cast if types match
|
||||
if (result is T typedResult)
|
||||
{
|
||||
return typedResult;
|
||||
}
|
||||
|
||||
// Handle conversion for compatible types
|
||||
return (T)Convert.ChangeType(result, typeof(T));
|
||||
}
|
||||
catch (InvalidCastException ex)
|
||||
{
|
||||
if (logger != null)
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Could not cast result to {TargetType} for {Context}. Result type: {ActualType}. Error: {Error}",
|
||||
typeof(T).Name,
|
||||
context ?? "query",
|
||||
result.GetType().Name,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
return defaultValue;
|
||||
}
|
||||
catch (FormatException ex)
|
||||
{
|
||||
if (logger != null)
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Format error converting result to {TargetType} for {Context}. Value: {Value}. Error: {Error}",
|
||||
typeof(T).Name,
|
||||
context ?? "query",
|
||||
result,
|
||||
ex.Message
|
||||
);
|
||||
}
|
||||
return defaultValue;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Safely attempts to rollback a transaction, catching and logging any errors.
|
||||
/// This prevents rollback errors from masking the original exception.
|
||||
/// </summary>
|
||||
/// <param name="transaction">The transaction to rollback</param>
|
||||
/// <param name="connection">The database connection (used to check if still open)</param>
|
||||
/// <param name="logger">Logger for warnings</param>
|
||||
/// <param name="context">Context for logging (e.g., table name)</param>
|
||||
public static void SafeRollback(
|
||||
this DbTransaction transaction,
|
||||
DbConnection? connection,
|
||||
ILogger logger,
|
||||
string? context = null)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (connection?.State == ConnectionState.Open)
|
||||
{
|
||||
transaction.Rollback();
|
||||
if (!string.IsNullOrEmpty(context))
|
||||
{
|
||||
logger.LogDebug("Transaction rolled back for {Context}", context);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
logger.LogWarning(
|
||||
"Could not rollback transaction for {Context}: connection is {State}",
|
||||
context ?? "operation",
|
||||
connection?.State.ToString() ?? "null"
|
||||
);
|
||||
}
|
||||
}
|
||||
catch (Exception rollbackEx)
|
||||
{
|
||||
logger.LogWarning(
|
||||
rollbackEx,
|
||||
"Error during transaction rollback for {Context}: {Message}",
|
||||
context ?? "operation",
|
||||
rollbackEx.Message
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
58
util/Seeder/Migration/Utils/DbSeederConstants.cs
Normal file
58
util/Seeder/Migration/Utils/DbSeederConstants.cs
Normal file
@@ -0,0 +1,58 @@
|
||||
namespace Bit.Seeder.Migration.Utils;
|
||||
|
||||
/// <summary>
|
||||
/// Constants used throughout the DbSeeder utility for database operations.
|
||||
/// </summary>
|
||||
public static class DbSeederConstants
|
||||
{
|
||||
/// <summary>
|
||||
/// Default sample size for type detection and data analysis.
|
||||
/// </summary>
|
||||
public const int DEFAULT_SAMPLE_SIZE = 100;
|
||||
|
||||
/// <summary>
|
||||
/// Default batch size for bulk insert operations.
|
||||
/// </summary>
|
||||
public const int DEFAULT_BATCH_SIZE = 1000;
|
||||
|
||||
/// <summary>
|
||||
/// Large batch size for optimized bulk operations (use with caution for max_allowed_packet limits).
|
||||
/// </summary>
|
||||
public const int LARGE_BATCH_SIZE = 5000;
|
||||
|
||||
/// <summary>
|
||||
/// Default connection timeout in seconds.
|
||||
/// </summary>
|
||||
public const int DEFAULT_CONNECTION_TIMEOUT = 30;
|
||||
|
||||
/// <summary>
|
||||
/// Default command timeout in seconds for regular operations.
|
||||
/// </summary>
|
||||
public const int DEFAULT_COMMAND_TIMEOUT = 60;
|
||||
|
||||
/// <summary>
|
||||
/// Extended command timeout in seconds for large batch operations.
|
||||
/// </summary>
|
||||
public const int LARGE_BATCH_COMMAND_TIMEOUT = 300; // 5 minutes
|
||||
|
||||
/// <summary>
|
||||
/// Default maximum pool size for database connections.
|
||||
/// </summary>
|
||||
public const int DEFAULT_MAX_POOL_SIZE = 100;
|
||||
|
||||
/// <summary>
|
||||
/// Threshold for enabling detailed progress logging (row count).
|
||||
/// Operations with fewer rows may use simpler logging.
|
||||
/// </summary>
|
||||
public const int LOGGING_THRESHOLD = 1000;
|
||||
|
||||
/// <summary>
|
||||
/// Batch size for progress reporting during long-running operations.
|
||||
/// </summary>
|
||||
public const int PROGRESS_REPORTING_INTERVAL = 10000;
|
||||
|
||||
/// <summary>
|
||||
/// Placeholder text for redacting passwords in connection strings for safe logging.
|
||||
/// </summary>
|
||||
public const string REDACTED_PASSWORD = "***REDACTED***";
|
||||
}
|
||||
79
util/Seeder/Migration/Utils/IdentifierValidator.cs
Normal file
79
util/Seeder/Migration/Utils/IdentifierValidator.cs
Normal file
@@ -0,0 +1,79 @@
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace Bit.Seeder.Migration.Utils;
|
||||
|
||||
/// <summary>
|
||||
/// Validates SQL identifiers (table names, column names, schema names) to prevent SQL injection.
|
||||
/// </summary>
|
||||
public static class IdentifierValidator
|
||||
{
|
||||
// Regex pattern for valid SQL identifiers: must start with letter or underscore,
|
||||
// followed by letters, numbers, or underscores
|
||||
private static readonly Regex ValidIdentifierPattern = new Regex(
|
||||
@"^[a-zA-Z_][a-zA-Z0-9_]*$",
|
||||
RegexOptions.Compiled
|
||||
);
|
||||
|
||||
// Maximum reasonable length for identifiers (most databases have limits around 128-255)
|
||||
private const int MaxIdentifierLength = 128;
|
||||
|
||||
/// <summary>
|
||||
/// Validates a SQL identifier (table name, column name, schema name).
|
||||
/// </summary>
|
||||
/// <param name="identifier">The identifier to validate</param>
|
||||
/// <returns>True if the identifier is valid, false otherwise</returns>
|
||||
public static bool IsValid(string? identifier)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(identifier))
|
||||
return false;
|
||||
|
||||
if (identifier.Length > MaxIdentifierLength)
|
||||
return false;
|
||||
|
||||
return ValidIdentifierPattern.IsMatch(identifier);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validates a SQL identifier and throws an exception if invalid.
|
||||
/// </summary>
|
||||
/// <param name="identifier">The identifier to validate</param>
|
||||
/// <param name="identifierType">The type of identifier (e.g., "table name", "column name")</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the identifier is invalid</exception>
|
||||
public static void ValidateOrThrow(string? identifier, string identifierType = "identifier")
|
||||
{
|
||||
if (!IsValid(identifier))
|
||||
{
|
||||
throw new ArgumentException(
|
||||
$"Invalid {identifierType}: '{identifier}'. " +
|
||||
$"Identifiers must start with a letter or underscore, " +
|
||||
$"contain only letters, numbers, and underscores, " +
|
||||
$"and be no longer than {MaxIdentifierLength} characters.",
|
||||
nameof(identifier)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Validates multiple identifiers and throws an exception if any are invalid.
|
||||
/// </summary>
|
||||
/// <param name="identifiers">The identifiers to validate</param>
|
||||
/// <param name="identifierType">The type of identifiers (e.g., "column names")</param>
|
||||
/// <exception cref="ArgumentException">Thrown when any identifier is invalid</exception>
|
||||
public static void ValidateAllOrThrow(IEnumerable<string> identifiers, string identifierType = "identifiers")
|
||||
{
|
||||
foreach (var identifier in identifiers)
|
||||
{
|
||||
ValidateOrThrow(identifier, identifierType);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Filters a list of identifiers to only include valid ones.
|
||||
/// </summary>
|
||||
/// <param name="identifiers">The identifiers to filter</param>
|
||||
/// <returns>A list of valid identifiers</returns>
|
||||
public static List<string> FilterValid(IEnumerable<string> identifiers)
|
||||
{
|
||||
return identifiers.Where(IsValid).ToList();
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,8 @@ public static class SecuritySanitizer
|
||||
|
||||
foreach (var (key, value) in configDict)
|
||||
{
|
||||
if (SensitiveFields.Contains(key.ToLower()))
|
||||
// Use case-insensitive comparison with proper culture handling
|
||||
if (SensitiveFields.Any(field => field.Equals(key, StringComparison.OrdinalIgnoreCase)))
|
||||
{
|
||||
sanitized[key] = value != null ? MaskPassword(value.ToString() ?? string.Empty) : string.Empty;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user