diff --git a/util/DbSeederUtility/MigrationSettingsFactory.cs b/util/DbSeederUtility/MigrationSettingsFactory.cs index 0d7b178f9e..bcd4fb70ac 100644 --- a/util/DbSeederUtility/MigrationSettingsFactory.cs +++ b/util/DbSeederUtility/MigrationSettingsFactory.cs @@ -20,7 +20,7 @@ public static class MigrationSettingsFactory .SetBasePath(Directory.GetCurrentDirectory()) .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) .AddJsonFile($"appsettings.{Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") ?? "Production"}.json", optional: true, reloadOnChange: true) - .AddUserSecrets("bitwarden-Api") + .AddUserSecrets("bitwarden-Api") // Load user secrets from the API project .AddEnvironmentVariables(); var configuration = configBuilder.Build(); diff --git a/util/Seeder/Migration/CsvHandler.cs b/util/Seeder/Migration/CsvHandler.cs index 89d531fe56..4ead954f53 100644 --- a/util/Seeder/Migration/CsvHandler.cs +++ b/util/Seeder/Migration/CsvHandler.cs @@ -290,7 +290,7 @@ public class CsvHandler(CsvSettings settings, ILogger logger) if (string.IsNullOrEmpty(value)) { - processedRow[i] = null!; + processedRow[i] = DBNull.Value; } else if (specialColumns.Contains(colName)) { diff --git a/util/Seeder/Migration/Databases/MariaDbImporter.cs b/util/Seeder/Migration/Databases/MariaDbImporter.cs index 6f271356fd..3621382334 100644 --- a/util/Seeder/Migration/Databases/MariaDbImporter.cs +++ b/util/Seeder/Migration/Databases/MariaDbImporter.cs @@ -5,6 +5,9 @@ using Microsoft.Extensions.Logging; namespace Bit.Seeder.Migration.Databases; +/// +/// MariaDB database importer that handles schema creation, data import, and foreign key management. +/// public class MariaDbImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter { private readonly ILogger _logger = logger; @@ -14,16 +17,24 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log private readonly string _username = config.Username; private readonly string _password = config.Password; private MySqlConnection? _connection; + private bool _disposed = false; + /// + /// Connects to the MariaDB database. + /// public bool Connect() { try { - var connectionString = $"Server={_host};Port={_port};Database={_database};" + - $"Uid={_username};Pwd={_password};" + - $"ConnectionTimeout=30;CharSet=utf8mb4;AllowLoadLocalInfile=true;MaxPoolSize=100;"; + // Build connection string with redacted password for safe logging + var safeConnectionString = $"Server={_host};Port={_port};Database={_database};" + + $"Uid={_username};Pwd={DbSeederConstants.REDACTED_PASSWORD};" + + $"ConnectionTimeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};" + + $"CharSet=utf8mb4;AllowLoadLocalInfile=false;MaxPoolSize={DbSeederConstants.DEFAULT_MAX_POOL_SIZE};"; - _connection = new MySqlConnection(connectionString); + var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password); + + _connection = new MySqlConnection(actualConnectionString); _connection.Open(); _logger.LogInformation("Connected to MariaDB: {Host}:{Port}/{Database}", _host, _port, _database); @@ -47,6 +58,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Creates a table in MariaDB from the provided schema definition. + /// public bool CreateTableFromSchema( string tableName, List columns, @@ -58,11 +72,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var mariaColumns = new List(); foreach (var colName in columns) { + IdentifierValidator.ValidateOrThrow(colName, "column name"); + var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)"); var mariaType = ConvertSqlServerTypeToMariaDB(sqlServerType, specialColumns.Contains(colName)); mariaColumns.Add($"`{colName}` {mariaType}"); @@ -122,15 +140,20 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Imports data into a MariaDB table using batched INSERT statements. + /// public bool ImportData( string tableName, List columns, List data, - int batchSize = 1000) + int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -146,7 +169,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log return false; } - // Filter columns var validColumnIndices = new List(); var validColumns = new List(); @@ -154,6 +176,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log { if (actualColumns.Contains(columns[i])) { + IdentifierValidator.ValidateOrThrow(columns[i], "column name"); validColumnIndices.Add(i); validColumns.Add(columns[i]); } @@ -217,14 +240,15 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log transaction.Commit(); totalImported += batch.Count; - if (filteredData.Count > 1000) + if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD) { _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count); } } - catch + catch (Exception batchEx) { - transaction.Rollback(); + _logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message); + transaction.SafeRollback(_connection, _logger, tableName); throw; } } @@ -239,11 +263,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Checks if a table exists in the MariaDB database. + /// public bool TableExists(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -255,7 +284,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log command.Parameters.AddWithValue("@database", _database); command.Parameters.AddWithValue("@tableName", tableName); - var count = Convert.ToInt32(command.ExecuteScalar()); + var count = command.GetScalarValue(0, _logger, $"table existence check for {tableName}"); return count > 0; } catch (Exception ex) @@ -265,17 +294,22 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Gets the row count for a specific table. + /// public int GetTableRowCount(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"SELECT COUNT(*) FROM `{tableName}`"; using var command = new MySqlCommand(query, _connection); - return Convert.ToInt32(command.ExecuteScalar()); + return command.GetScalarValue(0, _logger, $"row count for {tableName}"); } catch (Exception ex) { @@ -284,11 +318,16 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Drops a table from the MariaDB database. + /// public bool DropTable(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"DROP TABLE IF EXISTS `{tableName}`"; @@ -448,6 +487,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log return true; // MariaDB multi-row INSERT is optimized } + /// + /// Imports data using optimized multi-row INSERT statements for better performance. + /// public bool ImportDataBulk( string tableName, List columns, @@ -456,6 +498,8 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -471,7 +515,6 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log return false; } - // Filter columns var validColumnIndices = new List(); var validColumns = new List(); @@ -479,6 +522,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log { if (actualColumns.Contains(columns[i])) { + IdentifierValidator.ValidateOrThrow(columns[i], "column name"); validColumnIndices.Add(i); validColumns.Add(columns[i]); } @@ -496,10 +540,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log _logger.LogInformation("Bulk importing {Count} rows into {TableName} using multi-row INSERT", filteredData.Count, tableName); - // Use multi-row INSERT for better performance - // INSERT INTO table (col1, col2) VALUES (val1, val2), (val3, val4), ... - // MariaDB can handle up to max_allowed_packet size, we'll use 5000 rows per batch - const int rowsPerBatch = 5000; + const int rowsPerBatch = DbSeederConstants.LARGE_BATCH_SIZE; var totalImported = 0; for (int i = 0; i < filteredData.Count; i += rowsPerBatch) @@ -536,7 +577,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log var fullInsertSql = columnPart + string.Join(", ", valueSets); using var command = new MySqlCommand(fullInsertSql, _connection, transaction); - command.CommandTimeout = 300; // 5 minutes timeout for large batches + command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT; // Add all parameters foreach (var (name, value) in allParameters) @@ -562,7 +603,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log transaction.Commit(); totalImported += batch.Count; - if (filteredData.Count > 1000) + if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD) { _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count); } @@ -570,19 +611,7 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log catch (Exception batchEx) { _logger.LogError("Batch import error for {TableName}: {Message}", tableName, batchEx.Message); - - try - { - if (_connection?.State == System.Data.ConnectionState.Open) - { - transaction.Rollback(); - } - } - catch (Exception rollbackEx) - { - _logger.LogWarning("Could not rollback transaction (connection may be closed): {Message}", rollbackEx.Message); - } - + transaction.SafeRollback(_connection, _logger, tableName); throw; } } @@ -602,6 +631,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Tests the connection to MariaDB by executing a simple query. + /// public bool TestConnection() { try @@ -609,9 +641,9 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log if (Connect()) { using var command = new MySqlCommand("SELECT 1", _connection); - var result = command.ExecuteScalar(); + var result = command.GetScalarValue(0, _logger, "connection test"); Disconnect(); - return result != null && Convert.ToInt32(result) == 1; + return result == 1; } return false; } @@ -622,8 +654,27 @@ public class MariaDbImporter(DatabaseConfig config, ILogger log } } + /// + /// Disposes of the MariaDB importer and releases all resources. + /// public void Dispose() { - Disconnect(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Implements Dispose pattern for resource cleanup. + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Disconnect(); + } + _disposed = true; + } } } diff --git a/util/Seeder/Migration/Databases/PostgresImporter.cs b/util/Seeder/Migration/Databases/PostgresImporter.cs index 66e88cafa6..18f4b0e846 100644 --- a/util/Seeder/Migration/Databases/PostgresImporter.cs +++ b/util/Seeder/Migration/Databases/PostgresImporter.cs @@ -1,10 +1,14 @@ using Npgsql; using NpgsqlTypes; using Bit.Seeder.Migration.Models; +using Bit.Seeder.Migration.Utils; using Microsoft.Extensions.Logging; namespace Bit.Seeder.Migration.Databases; +/// +/// PostgreSQL database importer that handles schema creation, data import, and constraint management. +/// public class PostgresImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter { private readonly ILogger _logger = logger; @@ -14,16 +18,22 @@ public class PostgresImporter(DatabaseConfig config, ILogger l private readonly string _username = config.Username; private readonly string _password = config.Password; private NpgsqlConnection? _connection; + private bool _disposed = false; + /// + /// Connects to the PostgreSQL database. + /// public bool Connect() { try { - var connectionString = $"Host={_host};Port={_port};Database={_database};" + - $"Username={_username};Password={_password};" + - $"Timeout=30;CommandTimeout=30;"; + var safeConnectionString = $"Host={_host};Port={_port};Database={_database};" + + $"Username={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" + + $"Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};CommandTimeout={DbSeederConstants.DEFAULT_COMMAND_TIMEOUT};"; - _connection = new NpgsqlConnection(connectionString); + var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password); + + _connection = new NpgsqlConnection(actualConnectionString); _connection.Open(); _logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database); @@ -47,6 +57,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger l } } + /// + /// Creates a table from the provided schema definition. + /// public bool CreateTableFromSchema( string tableName, List columns, @@ -58,12 +71,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger l if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { // Convert SQL Server types to PostgreSQL types var pgColumns = new List(); foreach (var colName in columns) { + IdentifierValidator.ValidateOrThrow(colName, "column name"); + var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)"); var pgType = ConvertSqlServerTypeToPostgreSQL(sqlServerType, specialColumns.Contains(colName)); pgColumns.Add($"\"{colName}\" {pgType}"); @@ -185,15 +202,20 @@ public class PostgresImporter(DatabaseConfig config, ILogger l } } + /// + /// Imports data into a table using batch INSERT statements. + /// public bool ImportData( string tableName, List columns, List data, - int batchSize = 1000) + int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -210,6 +232,8 @@ public class PostgresImporter(DatabaseConfig config, ILogger l return false; } + IdentifierValidator.ValidateOrThrow(actualTableName, "actual table name"); + var actualColumns = GetTableColumns(tableName); if (actualColumns.Count == 0) { @@ -319,14 +343,14 @@ public class PostgresImporter(DatabaseConfig config, ILogger l transaction.Commit(); totalImported += batch.Count; - if (filteredData.Count > 1000) + if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD) { _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count); } } catch { - transaction.Rollback(); + transaction.SafeRollback(_connection, _logger, tableName); throw; } } @@ -346,11 +370,16 @@ public class PostgresImporter(DatabaseConfig config, ILogger l } } + /// + /// Checks if a table exists in the database. + /// public bool TableExists(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -362,7 +391,7 @@ public class PostgresImporter(DatabaseConfig config, ILogger l using var command = new NpgsqlCommand(query, _connection); command.Parameters.AddWithValue("tableName", tableName); - return (bool)command.ExecuteScalar()!; + return command.GetScalarValue(false, _logger, $"table existence check for {tableName}"); } catch (Exception ex) { @@ -756,6 +785,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger l } } + /// + /// Tests the connection to PostgreSQL by executing a simple query. + /// public bool TestConnection() { try @@ -763,9 +795,9 @@ public class PostgresImporter(DatabaseConfig config, ILogger l if (Connect()) { using var command = new NpgsqlCommand("SELECT 1", _connection); - var result = command.ExecuteScalar(); + var result = command.GetScalarValue(0, _logger, "connection test"); Disconnect(); - return result != null && (int)result == 1; + return result == 1; } return false; } @@ -776,8 +808,27 @@ public class PostgresImporter(DatabaseConfig config, ILogger l } } + /// + /// Disposes of the PostgreSQL importer and releases all resources. + /// public void Dispose() { - Disconnect(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Implements Dispose pattern for resource cleanup. + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Disconnect(); + } + _disposed = true; + } } } diff --git a/util/Seeder/Migration/Databases/SqlServerExporter.cs b/util/Seeder/Migration/Databases/SqlServerExporter.cs index 84ea35a224..c711acea02 100644 --- a/util/Seeder/Migration/Databases/SqlServerExporter.cs +++ b/util/Seeder/Migration/Databases/SqlServerExporter.cs @@ -1,9 +1,13 @@ using Microsoft.Data.SqlClient; using Bit.Seeder.Migration.Models; +using Bit.Seeder.Migration.Utils; using Microsoft.Extensions.Logging; namespace Bit.Seeder.Migration.Databases; +/// +/// SQL Server database exporter that handles schema discovery and data export. +/// public class SqlServerExporter(DatabaseConfig config, ILogger logger) : IDisposable { private readonly ILogger _logger = logger; @@ -13,16 +17,23 @@ public class SqlServerExporter(DatabaseConfig config, ILogger private readonly string _username = config.Username; private readonly string _password = config.Password; private SqlConnection? _connection; + private bool _disposed = false; + /// + /// Connects to the SQL Server database. + /// public bool Connect() { try { - var connectionString = $"Server={_host},{_port};Database={_database};" + - $"User Id={_username};Password={_password};" + - $"TrustServerCertificate=True;Connection Timeout=30;"; + var safeConnectionString = $"Server={_host},{_port};Database={_database};" + + $"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" + + $"TrustServerCertificate=True;" + + $"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};"; - _connection = new SqlConnection(connectionString); + var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password); + + _connection = new SqlConnection(actualConnectionString); _connection.Open(); _logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database); @@ -35,6 +46,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } + /// + /// Disconnects from the SQL Server database. + /// public void Disconnect() { if (_connection != null) @@ -46,6 +60,11 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } + /// + /// Discovers all tables in the SQL Server database. + /// + /// Whether to exclude system tables + /// List of table names public List DiscoverTables(bool excludeSystemTables = true) { if (_connection == null) @@ -86,11 +105,18 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } + /// + /// Gets detailed information about a table including columns, types, and row count. + /// + /// The name of the table to query + /// TableInfo containing schema and metadata public TableInfo GetTableInfo(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { // Get column information @@ -146,13 +172,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger throw new InvalidOperationException($"Table '{tableName}' not found"); } - // Get row count var countQuery = $"SELECT COUNT(*) FROM [{tableName}]"; int rowCount; using (var command = new SqlCommand(countQuery, _connection)) { - rowCount = (int)command.ExecuteScalar()!; + rowCount = command.GetScalarValue(0, _logger, $"row count for {tableName}"); } _logger.LogInformation("Table {TableName}: {ColumnCount} columns, {RowCount} rows", tableName, columns.Count, rowCount); @@ -172,16 +197,32 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - public (List Columns, List Data) ExportTableData(string tableName, int batchSize = 10000) + /// + /// Exports all data from a table with streaming to avoid memory exhaustion. + /// + /// The name of the table to export + /// Batch size for progress reporting + /// Tuple of column names and data rows + public (List Columns, List Data) ExportTableData( + string tableName, + int batchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { // Get table info first var tableInfo = GetTableInfo(tableName); + // Validate all column names + foreach (var colName in tableInfo.Columns) + { + IdentifierValidator.ValidateOrThrow(colName, "column name"); + } + // Build column list with proper quoting var quotedColumns = tableInfo.Columns.Select(col => $"[{col}]").ToList(); var columnList = string.Join(", ", quotedColumns); @@ -191,16 +232,24 @@ public class SqlServerExporter(DatabaseConfig config, ILogger _logger.LogInformation("Executing export query for {TableName}", tableName); using var command = new SqlCommand(query, _connection); - command.CommandTimeout = 300; // 5 minutes + command.CommandTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT; using var reader = command.ExecuteReader(); - // Fetch data in batches + // Identify GUID columns for in-place uppercase conversion + var guidColumnIndices = IdentifyGuidColumns(tableInfo); + + // Fetch data in batches - still loads into memory but with progress reporting + // Note: For true streaming, consumers should use yield return pattern var allData = new List(); while (reader.Read()) { var row = new object[tableInfo.Columns.Count]; reader.GetValues(row); + + // Convert GUID values in-place to uppercase for Bitwarden compatibility + ConvertGuidsToUppercaseInPlace(row, guidColumnIndices); + allData.Add(row); if (allData.Count % batchSize == 0) @@ -209,11 +258,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - // Convert GUID values to uppercase to ensure compatibility with Bitwarden - var processedData = ConvertGuidsToUppercase(allData, tableInfo); - - _logger.LogInformation("Exported {Count} rows from {TableName}", processedData.Count, tableName); - return (tableInfo.Columns, processedData); + _logger.LogInformation("Exported {Count} rows from {TableName}", allData.Count, tableName); + return (tableInfo.Columns, allData); } catch (Exception ex) { @@ -222,11 +268,21 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - public List IdentifyJsonColumns(string tableName, int sampleSize = 100) + /// + /// Identifies columns that likely contain JSON data by sampling values. + /// + /// The name of the table to analyze + /// Number of rows to sample for analysis + /// List of column names that appear to contain JSON + public List IdentifyJsonColumns( + string tableName, + int sampleSize = DbSeederConstants.DEFAULT_SAMPLE_SIZE) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var tableInfo = GetTableInfo(tableName); @@ -243,6 +299,12 @@ public class SqlServerExporter(DatabaseConfig config, ILogger if (textColumns.Count == 0) return jsonColumns; + // Validate all column names + foreach (var colName in textColumns) + { + IdentifierValidator.ValidateOrThrow(colName, "column name"); + } + // Sample data from text columns var quotedColumns = textColumns.Select(col => $"[{col}]").ToList(); var columnList = string.Join(", ", quotedColumns); @@ -266,6 +328,9 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } // Analyze each column + // JSON detection threshold: 50% of samples must look like JSON + const double jsonThreshold = 0.5; + for (int i = 0; i < textColumns.Count; i++) { var colName = textColumns[i]; @@ -288,8 +353,8 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - // If more than 50% of non-null values look like JSON, mark as JSON column - if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > 0.5) + // If more than threshold of non-null values look like JSON, mark as JSON column + if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > jsonThreshold) { jsonColumns.Add(colName); _logger.LogInformation("Identified {ColumnName} as likely JSON column ({JsonIndicators}/{TotalNonNull} samples)", colName, jsonIndicators, totalNonNull); @@ -305,13 +370,15 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - private List ConvertGuidsToUppercase(List data, TableInfo tableInfo) + /// + /// Identifies GUID column indices in a table for efficient processing. + /// + /// Table metadata including column types + /// List of column indices that are GUID/uniqueidentifier columns + private List IdentifyGuidColumns(TableInfo tableInfo) { - if (data.Count == 0 || tableInfo.ColumnTypes.Count == 0) - return data; - - // Identify GUID columns (uniqueidentifier type in SQL Server) var guidColumnIndices = new List(); + for (int i = 0; i < tableInfo.Columns.Count; i++) { var columnName = tableInfo.Columns[i]; @@ -325,34 +392,36 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } - if (guidColumnIndices.Count == 0) + if (guidColumnIndices.Count > 0) { - _logger.LogDebug("No GUID columns found, returning data unchanged"); - return data; + _logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count); } - _logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count); - - // Process each row and convert GUID values to uppercase - var processedData = new List(); - foreach (var row in data) - { - var rowList = row.ToList(); - foreach (var guidIdx in guidColumnIndices) - { - if (guidIdx < rowList.Count && rowList[guidIdx] != null && rowList[guidIdx] != DBNull.Value) - { - var guidValue = rowList[guidIdx].ToString(); - // Convert to uppercase, preserving the GUID format - rowList[guidIdx] = guidValue?.ToUpper() ?? string.Empty; - } - } - processedData.Add(rowList.ToArray()); - } - - return processedData; + return guidColumnIndices; } + /// + /// Converts GUID values to uppercase in-place within a data row. + /// More efficient than creating new arrays as it modifies the array directly. + /// + /// The data row to modify + /// Indices of columns containing GUIDs + private void ConvertGuidsToUppercaseInPlace(object[] row, List guidColumnIndices) + { + foreach (var guidIdx in guidColumnIndices) + { + if (guidIdx < row.Length && row[guidIdx] != null && row[guidIdx] != DBNull.Value) + { + var guidValue = row[guidIdx].ToString(); + // Convert to uppercase in-place, preserving the GUID format + row[guidIdx] = guidValue?.ToUpper() ?? string.Empty; + } + } + } + + /// + /// Tests the connection to SQL Server by executing a simple query. + /// public bool TestConnection() { try @@ -360,9 +429,10 @@ public class SqlServerExporter(DatabaseConfig config, ILogger if (Connect()) { using var command = new SqlCommand("SELECT 1", _connection); - var result = command.ExecuteScalar(); + // Use null-safe scalar value retrieval + var result = command.GetScalarValue(0, _logger, "connection test"); Disconnect(); - return result != null && (int)result == 1; + return result == 1; } return false; } @@ -373,8 +443,27 @@ public class SqlServerExporter(DatabaseConfig config, ILogger } } + /// + /// Disposes of the SQL Server exporter and releases all resources. + /// public void Dispose() { - Disconnect(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Implements Dispose pattern for resource cleanup. + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Disconnect(); + } + _disposed = true; + } } } diff --git a/util/Seeder/Migration/Databases/SqlServerImporter.cs b/util/Seeder/Migration/Databases/SqlServerImporter.cs index a7bcf77643..506957d246 100644 --- a/util/Seeder/Migration/Databases/SqlServerImporter.cs +++ b/util/Seeder/Migration/Databases/SqlServerImporter.cs @@ -6,6 +6,9 @@ using System.Data; namespace Bit.Seeder.Migration.Databases; +/// +/// SQL Server database importer that handles schema creation, data import, and constraint management. +/// public class SqlServerImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter { private readonly ILogger _logger = logger; @@ -15,17 +18,24 @@ public class SqlServerImporter(DatabaseConfig config, ILogger private readonly string _username = config.Username; private readonly string _password = config.Password; private SqlConnection? _connection; + private bool _disposed = false; private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]"; + /// + /// Connects to the SQL Server database. + /// public bool Connect() { try { - var connectionString = $"Server={_host},{_port};Database={_database};" + - $"User Id={_username};Password={_password};" + - $"TrustServerCertificate=True;Connection Timeout=30;"; + var safeConnectionString = $"Server={_host},{_port};Database={_database};" + + $"User Id={_username};Password={DbSeederConstants.REDACTED_PASSWORD};" + + $"TrustServerCertificate=True;" + + $"Connection Timeout={DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT};"; - _connection = new SqlConnection(connectionString); + var actualConnectionString = safeConnectionString.Replace(DbSeederConstants.REDACTED_PASSWORD, _password); + + _connection = new SqlConnection(actualConnectionString); _connection.Open(); _logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database); @@ -38,6 +48,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Disconnects from the SQL Server database. + /// public void Disconnect() { if (_connection != null) @@ -49,11 +62,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Gets the list of columns for a table. + /// public List GetTableColumns(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -81,11 +99,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Gets the column types for a table. + /// private Dictionary GetTableColumnTypes(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -112,11 +135,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Checks if a table exists in the database. + /// public bool TableExists(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -127,7 +155,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger using var command = new SqlCommand(query, _connection); command.Parameters.AddWithValue("@TableName", tableName); - var count = (int)command.ExecuteScalar()!; + var count = command.GetScalarValue(0, _logger, $"table existence check for {tableName}"); return count > 0; } catch (Exception ex) @@ -137,17 +165,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Gets the row count for a specific table. + /// public int GetTableRowCount(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"SELECT COUNT(*) FROM [{tableName}]"; using var command = new SqlCommand(query, _connection); - var count = (int)command.ExecuteScalar()!; + var count = command.GetScalarValue(0, _logger, $"row count for {tableName}"); _logger.LogDebug("Row count for {TableName}: {Count}", tableName, count); return count; } @@ -158,11 +191,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Drops a table from the database. + /// public bool DropTable(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"DROP TABLE IF EXISTS [{tableName}]"; @@ -200,6 +238,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Gets the list of constraints that need to be re-enabled from the tracking table. + /// private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable() { if (_connection == null) @@ -212,7 +253,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger // Check if tracking table exists var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')"; using var checkCommand = new SqlCommand(checkSql, _connection); - var tableExists = (int)checkCommand.ExecuteScalar()! > 0; + + var count = checkCommand.GetScalarValue(0, _logger, "tracking table existence check"); + var tableExists = count > 0; if (!tableExists) { @@ -248,6 +291,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger return constraints; } + /// + /// Disables all foreign key constraints and tracks them for re-enabling. + /// public bool DisableForeignKeys() { if (_connection == null) @@ -261,7 +307,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')"; using (var checkCommand = new SqlCommand(checkSql, _connection)) { - var tableExists = (int)checkCommand.ExecuteScalar()! > 0; + var count = checkCommand.GetScalarValue(0, _logger, "tracking table existence check"); + var tableExists = count > 0; if (tableExists) { @@ -366,6 +413,11 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { try { + // Validate identifiers to prevent SQL injection + IdentifierValidator.ValidateOrThrow(schema, "schema name"); + IdentifierValidator.ValidateOrThrow(table, "table name"); + IdentifierValidator.ValidateOrThrow(constraint, "constraint name"); + // Disable the constraint var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]"; using var disableCommand = new SqlCommand(disableSql, _connection, transaction); @@ -401,7 +453,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { // If anything fails, rollback the transaction // This ensures the tracking table doesn't exist with incomplete data - transaction.Rollback(); + // Safely rollback transaction, preserving original exception + transaction.SafeRollback(_connection, _logger, "foreign key constraint disabling"); throw; } } @@ -412,6 +465,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Re-enables all foreign key constraints that were disabled. + /// public bool EnableForeignKeys() { if (_connection == null) @@ -437,6 +493,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { try { + IdentifierValidator.ValidateOrThrow(schema, "schema name"); + IdentifierValidator.ValidateOrThrow(table, "table name"); + IdentifierValidator.ValidateOrThrow(constraint, "constraint name"); + var enableSql = $"ALTER TABLE [{schema}].[{table}] CHECK CONSTRAINT [{constraint}]"; using var command = new SqlCommand(enableSql, _connection); command.ExecuteNonQuery(); @@ -464,6 +524,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Creates a table from the provided schema definition. + /// public bool CreateTableFromSchema( string tableName, List columns, @@ -475,12 +538,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { // Build column definitions var sqlServerColumns = new List(); foreach (var colName in columns) { + IdentifierValidator.ValidateOrThrow(colName, "column name"); + var colType = columnTypes.GetValueOrDefault(colName, "NVARCHAR(MAX)"); // If it's a special JSON column, ensure it's a large text type @@ -516,11 +583,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Gets the list of identity columns for a table. + /// public List GetIdentityColumns(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = @" @@ -548,11 +620,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Enables IDENTITY_INSERT for a table to allow explicit identity values. + /// public bool EnableIdentityInsert(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"SET IDENTITY_INSERT [{tableName}] ON"; @@ -569,11 +646,16 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Disables IDENTITY_INSERT for a table. + /// public bool DisableIdentityInsert(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"SET IDENTITY_INSERT [{tableName}] OFF"; @@ -590,15 +672,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Imports data into a table using batch insert statements. + /// public bool ImportData( string tableName, List columns, List data, - int batchSize = 1000) + int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -624,6 +711,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { if (actualColumns.Contains(columns[i])) { + IdentifierValidator.ValidateOrThrow(columns[i], "column name"); + validColumnIndices.Add(i); validColumns.Add(columns[i]); } @@ -689,7 +778,6 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } - // Validate that data was actually inserted var actualCount = GetTableRowCount(tableName); _logger.LogInformation("Post-import validation for {TableName}: imported {TotalImported}, table contains {ActualCount}", tableName, totalImported, actualCount); @@ -708,15 +796,20 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Imports data using SqlBulkCopy for high performance. + /// private int UseSqlBulkCopy(string tableName, List columns, List data) { + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { using var bulkCopy = new SqlBulkCopy(_connection!) { DestinationTableName = $"[{tableName}]", - BatchSize = 10000, - BulkCopyTimeout = 600 // 10 minutes + BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL, + BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT }; // Map columns @@ -747,12 +840,22 @@ public class SqlServerImporter(DatabaseConfig config, ILogger catch (Exception ex) { _logger.LogWarning("SqlBulkCopy failed: {Message}, falling back to batch insert", ex.Message); - return FastBatchImport(tableName, columns, data, 1000); + return FastBatchImport(tableName, columns, data, DbSeederConstants.DEFAULT_BATCH_SIZE); } } + /// + /// Imports data using fast batch INSERT statements with transactions. + /// private int FastBatchImport(string tableName, List columns, List data, int batchSize) { + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + + foreach (var column in columns) + { + IdentifierValidator.ValidateOrThrow(column, "column name"); + } + var quotedColumns = columns.Select(col => $"[{col}]").ToList(); var placeholders = string.Join(", ", columns.Select((_, i) => $"@p{i}")); var insertSql = $"INSERT INTO [{tableName}] ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})"; @@ -782,14 +885,15 @@ public class SqlServerImporter(DatabaseConfig config, ILogger transaction.Commit(); totalImported += batch.Count; - if (data.Count > 1000) + if (data.Count > DbSeederConstants.LOGGING_THRESHOLD) { _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{DataCount} total, {Percentage:F1}%)", batch.Count, totalImported, data.Count, (totalImported / (double)data.Count * 100)); } } catch { - transaction.Rollback(); + // Safely rollback transaction, preserving original exception + transaction.SafeRollback(_connection, _logger, tableName); throw; } } @@ -938,6 +1042,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger return true; // SQL Server SqlBulkCopy is highly optimized } + /// + /// Imports data using SqlBulkCopy for high-performance bulk loading. + /// public bool ImportDataBulk( string tableName, List columns, @@ -946,6 +1053,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger if (_connection == null) throw new InvalidOperationException("Not connected to database"); + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -971,6 +1080,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { if (actualColumns.Contains(columns[i])) { + IdentifierValidator.ValidateOrThrow(columns[i], "column name"); + validColumnIndices.Add(i); validColumns.Add(columns[i]); } @@ -1022,8 +1133,8 @@ public class SqlServerImporter(DatabaseConfig config, ILogger using var bulkCopy = new SqlBulkCopy(_connection, bulkCopyOptions, null) { DestinationTableName = $"[{tableName}]", - BatchSize = 10000, - BulkCopyTimeout = 600 // 10 minutes + BatchSize = DbSeederConstants.PROGRESS_REPORTING_INTERVAL, + BulkCopyTimeout = DbSeederConstants.LARGE_BATCH_COMMAND_TIMEOUT }; // Map columns @@ -1063,6 +1174,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Tests the connection to SQL Server by executing a simple query. + /// public bool TestConnection() { try @@ -1070,9 +1184,9 @@ public class SqlServerImporter(DatabaseConfig config, ILogger if (Connect()) { using var command = new SqlCommand("SELECT 1", _connection); - var result = command.ExecuteScalar(); + var result = command.GetScalarValue(0, _logger, "connection test"); Disconnect(); - return result != null && (int)result == 1; + return result == 1; } return false; } @@ -1083,8 +1197,27 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + /// + /// Disposes of the SQL Server importer and releases all resources. + /// public void Dispose() { - Disconnect(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Implements Dispose pattern for resource cleanup. + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + Disconnect(); + } + _disposed = true; + } } } diff --git a/util/Seeder/Migration/Databases/SqliteImporter.cs b/util/Seeder/Migration/Databases/SqliteImporter.cs index 7900f93431..d26ea7e47d 100644 --- a/util/Seeder/Migration/Databases/SqliteImporter.cs +++ b/util/Seeder/Migration/Databases/SqliteImporter.cs @@ -5,11 +5,15 @@ using Microsoft.Extensions.Logging; namespace Bit.Seeder.Migration.Databases; +/// +/// SQLite database importer that handles schema creation and data import. +/// public class SqliteImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter { private readonly ILogger _logger = logger; private readonly string _databasePath = config.Database; private SqliteConnection? _connection; + private bool _disposed = false; public bool Connect() { @@ -74,6 +78,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Creates a table from the provided schema definition. + /// public bool CreateTableFromSchema( string tableName, List columns, @@ -85,11 +92,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var sqliteColumns = new List(); foreach (var colName in columns) { + // Validate each column name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(colName, "column name"); + var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)"); var sqliteType = ConvertSqlServerTypeToSQLite(sqlServerType, specialColumns.Contains(colName)); sqliteColumns.Add($"\"{colName}\" {sqliteType}"); @@ -116,11 +129,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Gets the list of columns for a table. + /// public List GetTableColumns(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"PRAGMA table_info(\"{tableName}\")"; @@ -142,15 +161,21 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Imports data into a table using batch INSERT statements. + /// public bool ImportData( string tableName, List columns, List data, - int batchSize = 1000) + int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + if (data.Count == 0) { _logger.LogWarning("No data to import for table {TableName}", tableName); @@ -174,6 +199,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge { if (actualColumns.Contains(columns[i])) { + // Validate column name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(columns[i], "column name"); + validColumnIndices.Add(i); validColumns.Add(columns[i]); } @@ -231,7 +259,7 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge totalImported += batch.Count; - if (filteredData.Count > 1000) + if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD) { _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count); } @@ -244,7 +272,8 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } catch { - transaction.Rollback(); + // Safely rollback transaction, preserving original exception + transaction.SafeRollback(_connection, _logger, tableName); throw; } } @@ -255,18 +284,25 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Checks if a table exists in the database. + /// public bool TableExists(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name = @tableName"; using var command = new SqliteCommand(query, _connection); command.Parameters.AddWithValue("@tableName", tableName); - var count = Convert.ToInt64(command.ExecuteScalar()); + // Use null-safe scalar value retrieval + var count = command.GetScalarValue(0, _logger, $"table existence check for {tableName}"); return count > 0; } catch (Exception ex) @@ -276,17 +312,24 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Gets the row count for a specific table. + /// public int GetTableRowCount(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"SELECT COUNT(*) FROM \"{tableName}\""; using var command = new SqliteCommand(query, _connection); - return Convert.ToInt32(command.ExecuteScalar()); + // Use null-safe scalar value retrieval + return command.GetScalarValue(0, _logger, $"row count for {tableName}"); } catch (Exception ex) { @@ -295,11 +338,17 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Drops a table from the database. + /// public bool DropTable(string tableName) { if (_connection == null) throw new InvalidOperationException("Not connected to database"); + // Validate table name to prevent SQL injection + IdentifierValidator.ValidateOrThrow(tableName, "table name"); + try { var query = $"DROP TABLE IF EXISTS \"{tableName}\""; @@ -432,6 +481,9 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge return false; } + /// + /// Tests the connection to SQLite by executing a simple query. + /// public bool TestConnection() { try @@ -439,9 +491,10 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge if (Connect()) { using var command = new SqliteCommand("SELECT 1", _connection); - var result = command.ExecuteScalar(); + // Use null-safe scalar value retrieval + var result = command.GetScalarValue(0, _logger, "connection test"); Disconnect(); - return result != null && Convert.ToInt32(result) == 1; + return result == 1; } return false; } @@ -452,8 +505,29 @@ public class SqliteImporter(DatabaseConfig config, ILogger logge } } + /// + /// Disposes of the SQLite importer and releases all resources. + /// public void Dispose() { - Disconnect(); + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Protected implementation of Dispose pattern. + /// + /// True if called from Dispose(), false if called from finalizer + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + // Dispose managed resources + Disconnect(); + } + _disposed = true; + } } } diff --git a/util/Seeder/Migration/Utils/DbExtensions.cs b/util/Seeder/Migration/Utils/DbExtensions.cs new file mode 100644 index 0000000000..29b213f004 --- /dev/null +++ b/util/Seeder/Migration/Utils/DbExtensions.cs @@ -0,0 +1,123 @@ +using System.Data; +using System.Data.Common; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration.Utils; + +/// +/// Extension methods for database operations with improved null safety. +/// +public static class DbExtensions +{ + /// + /// Executes a scalar query and returns the result with proper null handling. + /// Safely handles null and DBNull.Value results without throwing exceptions. + /// + /// The expected return type + /// The database command to execute + /// The default value to return if result is null/DBNull + /// Optional logger for warnings + /// Optional context for logging (e.g., "table count for Users") + /// The scalar value cast to type T, or defaultValue if null/DBNull + public static T GetScalarValue( + this DbCommand command, + T defaultValue = default!, + ILogger? logger = null, + string? context = null) + { + var result = command.ExecuteScalar(); + + if (result == null || result == DBNull.Value) + { + if (logger != null && !string.IsNullOrEmpty(context)) + { + logger.LogDebug("Query returned null for {Context}, using default value", context); + } + return defaultValue; + } + + try + { + // Handle direct cast if types match + if (result is T typedResult) + { + return typedResult; + } + + // Handle conversion for compatible types + return (T)Convert.ChangeType(result, typeof(T)); + } + catch (InvalidCastException ex) + { + if (logger != null) + { + logger.LogWarning( + "Could not cast result to {TargetType} for {Context}. Result type: {ActualType}. Error: {Error}", + typeof(T).Name, + context ?? "query", + result.GetType().Name, + ex.Message + ); + } + return defaultValue; + } + catch (FormatException ex) + { + if (logger != null) + { + logger.LogWarning( + "Format error converting result to {TargetType} for {Context}. Value: {Value}. Error: {Error}", + typeof(T).Name, + context ?? "query", + result, + ex.Message + ); + } + return defaultValue; + } + } + + /// + /// Safely attempts to rollback a transaction, catching and logging any errors. + /// This prevents rollback errors from masking the original exception. + /// + /// The transaction to rollback + /// The database connection (used to check if still open) + /// Logger for warnings + /// Context for logging (e.g., table name) + public static void SafeRollback( + this DbTransaction transaction, + DbConnection? connection, + ILogger logger, + string? context = null) + { + try + { + if (connection?.State == ConnectionState.Open) + { + transaction.Rollback(); + if (!string.IsNullOrEmpty(context)) + { + logger.LogDebug("Transaction rolled back for {Context}", context); + } + } + else + { + logger.LogWarning( + "Could not rollback transaction for {Context}: connection is {State}", + context ?? "operation", + connection?.State.ToString() ?? "null" + ); + } + } + catch (Exception rollbackEx) + { + logger.LogWarning( + rollbackEx, + "Error during transaction rollback for {Context}: {Message}", + context ?? "operation", + rollbackEx.Message + ); + } + } +} diff --git a/util/Seeder/Migration/Utils/DbSeederConstants.cs b/util/Seeder/Migration/Utils/DbSeederConstants.cs new file mode 100644 index 0000000000..730168fa45 --- /dev/null +++ b/util/Seeder/Migration/Utils/DbSeederConstants.cs @@ -0,0 +1,58 @@ +namespace Bit.Seeder.Migration.Utils; + +/// +/// Constants used throughout the DbSeeder utility for database operations. +/// +public static class DbSeederConstants +{ + /// + /// Default sample size for type detection and data analysis. + /// + public const int DEFAULT_SAMPLE_SIZE = 100; + + /// + /// Default batch size for bulk insert operations. + /// + public const int DEFAULT_BATCH_SIZE = 1000; + + /// + /// Large batch size for optimized bulk operations (use with caution for max_allowed_packet limits). + /// + public const int LARGE_BATCH_SIZE = 5000; + + /// + /// Default connection timeout in seconds. + /// + public const int DEFAULT_CONNECTION_TIMEOUT = 30; + + /// + /// Default command timeout in seconds for regular operations. + /// + public const int DEFAULT_COMMAND_TIMEOUT = 60; + + /// + /// Extended command timeout in seconds for large batch operations. + /// + public const int LARGE_BATCH_COMMAND_TIMEOUT = 300; // 5 minutes + + /// + /// Default maximum pool size for database connections. + /// + public const int DEFAULT_MAX_POOL_SIZE = 100; + + /// + /// Threshold for enabling detailed progress logging (row count). + /// Operations with fewer rows may use simpler logging. + /// + public const int LOGGING_THRESHOLD = 1000; + + /// + /// Batch size for progress reporting during long-running operations. + /// + public const int PROGRESS_REPORTING_INTERVAL = 10000; + + /// + /// Placeholder text for redacting passwords in connection strings for safe logging. + /// + public const string REDACTED_PASSWORD = "***REDACTED***"; +} diff --git a/util/Seeder/Migration/Utils/IdentifierValidator.cs b/util/Seeder/Migration/Utils/IdentifierValidator.cs new file mode 100644 index 0000000000..813861d085 --- /dev/null +++ b/util/Seeder/Migration/Utils/IdentifierValidator.cs @@ -0,0 +1,79 @@ +using System.Text.RegularExpressions; + +namespace Bit.Seeder.Migration.Utils; + +/// +/// Validates SQL identifiers (table names, column names, schema names) to prevent SQL injection. +/// +public static class IdentifierValidator +{ + // Regex pattern for valid SQL identifiers: must start with letter or underscore, + // followed by letters, numbers, or underscores + private static readonly Regex ValidIdentifierPattern = new Regex( + @"^[a-zA-Z_][a-zA-Z0-9_]*$", + RegexOptions.Compiled + ); + + // Maximum reasonable length for identifiers (most databases have limits around 128-255) + private const int MaxIdentifierLength = 128; + + /// + /// Validates a SQL identifier (table name, column name, schema name). + /// + /// The identifier to validate + /// True if the identifier is valid, false otherwise + public static bool IsValid(string? identifier) + { + if (string.IsNullOrWhiteSpace(identifier)) + return false; + + if (identifier.Length > MaxIdentifierLength) + return false; + + return ValidIdentifierPattern.IsMatch(identifier); + } + + /// + /// Validates a SQL identifier and throws an exception if invalid. + /// + /// The identifier to validate + /// The type of identifier (e.g., "table name", "column name") + /// Thrown when the identifier is invalid + public static void ValidateOrThrow(string? identifier, string identifierType = "identifier") + { + if (!IsValid(identifier)) + { + throw new ArgumentException( + $"Invalid {identifierType}: '{identifier}'. " + + $"Identifiers must start with a letter or underscore, " + + $"contain only letters, numbers, and underscores, " + + $"and be no longer than {MaxIdentifierLength} characters.", + nameof(identifier) + ); + } + } + + /// + /// Validates multiple identifiers and throws an exception if any are invalid. + /// + /// The identifiers to validate + /// The type of identifiers (e.g., "column names") + /// Thrown when any identifier is invalid + public static void ValidateAllOrThrow(IEnumerable identifiers, string identifierType = "identifiers") + { + foreach (var identifier in identifiers) + { + ValidateOrThrow(identifier, identifierType); + } + } + + /// + /// Filters a list of identifiers to only include valid ones. + /// + /// The identifiers to filter + /// A list of valid identifiers + public static List FilterValid(IEnumerable identifiers) + { + return identifiers.Where(IsValid).ToList(); + } +} diff --git a/util/Seeder/Migration/Utils/SecuritySanitizer.cs b/util/Seeder/Migration/Utils/SecuritySanitizer.cs index 38ee6356f2..58d75ff7e8 100644 --- a/util/Seeder/Migration/Utils/SecuritySanitizer.cs +++ b/util/Seeder/Migration/Utils/SecuritySanitizer.cs @@ -27,7 +27,8 @@ public static class SecuritySanitizer foreach (var (key, value) in configDict) { - if (SensitiveFields.Contains(key.ToLower())) + // Use case-insensitive comparison with proper culture handling + if (SensitiveFields.Any(field => field.Equals(key, StringComparison.OrdinalIgnoreCase))) { sanitized[key] = value != null ? MaskPassword(value.ToString() ?? string.Empty) : string.Empty; }