From 8cf7327ca6ef930433082d7738980d308b313c5f Mon Sep 17 00:00:00 2001
From: Mark Kincaid
Date: Wed, 29 Oct 2025 15:40:49 -0700
Subject: [PATCH] Added bulk copy

---
 .../Migration/Databases/IDatabaseImporter.cs  |  24 ++
 .../Migration/Databases/MariaDbImporter.cs    | 145 +++++++++++
 .../Migration/Databases/PostgresImporter.cs   | 231 +++++++++++++++++
 .../Migration/Databases/SqlServerImporter.cs  | 234 ++++++++++++++++++
 .../Migration/Databases/SqliteImporter.cs     |  17 ++
 util/Seeder/Recipes/CsvMigrationRecipe.cs     |  22 +-
 6 files changed, 671 insertions(+), 2 deletions(-)

diff --git a/util/Seeder/Migration/Databases/IDatabaseImporter.cs b/util/Seeder/Migration/Databases/IDatabaseImporter.cs
index 6da2da8a66..54035056e8 100644
--- a/util/Seeder/Migration/Databases/IDatabaseImporter.cs
+++ b/util/Seeder/Migration/Databases/IDatabaseImporter.cs
@@ -84,6 +84,30 @@ public interface IDatabaseImporter : IDisposable
     /// <returns>True if foreign keys were enabled successfully, false otherwise.</returns>
     bool EnableForeignKeys();
 
+    /// <summary>
+    /// Checks if this importer supports optimized bulk copy operations.
+    /// </summary>
+    /// <returns>True if bulk copy is supported and should be preferred over row-by-row import.</returns>
+    bool SupportsBulkCopy();
+
+    /// <summary>
+    /// Imports data into a table using database-specific bulk copy operations for optimal performance.
+    /// This method uses native bulk import mechanisms like PostgreSQL COPY, SQL Server SqlBulkCopy,
+    /// or multi-row INSERT statements for databases that support them.
+    /// </summary>
+    /// <param name="tableName">Name of the target table.</param>
+    /// <param name="columns">List of column names in the data.</param>
+    /// <param name="data">Data rows to import.</param>
+    /// <returns>True if bulk import was successful, false otherwise.</returns>
+    /// <remarks>
+    /// This method is significantly faster than ImportData() for large datasets (10-100x speedup).
+    /// If this method returns false, the caller should fall back to ImportData().
+    /// </remarks>
+    bool ImportDataBulk(
+        string tableName,
+        List<string> columns,
+        List<object?[]> data);
+
     /// <summary>
     /// Tests the connection to the database.
     /// </summary>
diff --git a/util/Seeder/Migration/Databases/MariaDbImporter.cs b/util/Seeder/Migration/Databases/MariaDbImporter.cs
index e3983f9f1f..3fe3244bec 100644
--- a/util/Seeder/Migration/Databases/MariaDbImporter.cs
+++ b/util/Seeder/Migration/Databases/MariaDbImporter.cs
@@ -412,6 +412,151 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
         }).ToArray();
     }
 
+    public bool SupportsBulkCopy()
+    {
+        return true; // MariaDB multi-row INSERT is optimized
+    }
+
+    public bool ImportDataBulk(
+        string tableName,
+        List<string> columns,
+        List<object?[]> data)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Filter columns
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumns.Contains(columns[i]))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(columns[i]);
+                }
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                return false;
+            }
+
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
+            ).ToList();
+
+            _logger.LogInformation("Bulk importing {Count} rows into {TableName} using multi-row INSERT", filteredData.Count, tableName);
+
+            // Use multi-row INSERT for better performance
+            // INSERT INTO table (col1, col2) VALUES (val1, val2), (val3, val4), ...
+            // MariaDB can handle up to max_allowed_packet size, we'll use 1000 rows per batch
+            const int rowsPerBatch = 1000;
+            var totalImported = 0;
+
+            for (int i = 0; i < filteredData.Count; i += rowsPerBatch)
+            {
+                var batch = filteredData.Skip(i).Take(rowsPerBatch).ToList();
+
+                using var transaction = _connection.BeginTransaction();
+                try
+                {
+                    // Build multi-row INSERT statement
+                    var quotedColumns = validColumns.Select(col => $"`{col}`").ToList();
+                    var columnPart = $"INSERT INTO `{tableName}` ({string.Join(", ", quotedColumns)}) VALUES ";
+
+                    var valueSets = new List<string>();
+                    var allParameters = new List<(string name, object value)>();
+                    var paramIndex = 0;
+
+                    foreach (var row in batch)
+                    {
+                        var preparedRow = PrepareRowForInsert(row, validColumns);
+                        var rowParams = new List<string>();
+
+                        for (int p = 0; p < preparedRow.Length; p++)
+                        {
+                            var paramName = $"@p{paramIndex}";
+                            rowParams.Add(paramName);
+                            allParameters.Add((paramName, preparedRow[p] ?? DBNull.Value));
+                            paramIndex++;
+                        }
+
+                        valueSets.Add($"({string.Join(", ", rowParams)})");
+                    }
+
+                    var fullInsertSql = columnPart + string.Join(", ", valueSets);
+
+                    using var command = new MySqlCommand(fullInsertSql, _connection, transaction);
+
+                    // Add all parameters
+                    foreach (var (name, value) in allParameters)
+                    {
+                        if (value is string strValue)
+                        {
+                            var param = new MySqlConnector.MySqlParameter
+                            {
+                                ParameterName = name,
+                                MySqlDbType = MySqlConnector.MySqlDbType.LongText,
+                                Value = strValue,
+                                Size = strValue.Length
+                            };
+                            command.Parameters.Add(param);
+                        }
+                        else
+                        {
+                            command.Parameters.AddWithValue(name, value);
+                        }
+                    }
+
+                    command.ExecuteNonQuery();
+                    transaction.Commit();
+                    totalImported += batch.Count;
+
+                    if (filteredData.Count > 1000)
+                    {
+                        _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
+                    }
+                }
+                catch
+                {
+                    transaction.Rollback();
+                    throw;
+                }
+            }
+
+            _logger.LogInformation("Successfully bulk imported {TotalImported} rows into {TableName}", totalImported, tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
+            _logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
+            if (ex.InnerException != null)
+            {
+                _logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
+            }
+            return false;
+        }
+    }
+
     public bool TestConnection()
     {
         try
diff --git a/util/Seeder/Migration/Databases/PostgresImporter.cs b/util/Seeder/Migration/Databases/PostgresImporter.cs
index 9d7d58857a..66e88cafa6 100644
--- a/util/Seeder/Migration/Databases/PostgresImporter.cs
+++ b/util/Seeder/Migration/Databases/PostgresImporter.cs
@@ -1,4 +1,5 @@
 using Npgsql;
+using NpgsqlTypes;
 using Bit.Seeder.Migration.Models;
 using Microsoft.Extensions.Logging;
 
@@ -525,6 +526,236 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
         }).ToArray();
     }
 
+    public bool SupportsBulkCopy()
+    {
+        return true; // PostgreSQL COPY is highly optimized
+    }
+
+    public bool ImportDataBulk(
+        string tableName,
+        List<string> columns,
+        List<object?[]> data)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            // Get the actual table name with correct casing
+            var actualTableName = GetActualTableName(tableName);
+            if (actualTableName == null)
+            {
+                _logger.LogError("Table {TableName} not found in database", tableName);
+                return false;
+            }
+
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Get column types from the database
+            var columnTypes = GetTableColumnTypes(tableName);
+
+            // Filter columns - use case-insensitive comparison
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+            var validColumnTypes = new List<string>();
+
+            // Create a case-insensitive lookup of actual columns
+            var actualColumnsLookup = actualColumns.ToDictionary(c => c, c => c, StringComparer.OrdinalIgnoreCase);
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumnsLookup.TryGetValue(columns[i], out var actualColumnName))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(actualColumnName);
+                    validColumnTypes.Add(columnTypes.GetValueOrDefault(actualColumnName, "text"));
+                }
+                else
+                {
+                    _logger.LogDebug("Column '{Column}' from CSV not found in table {TableName}", columns[i], tableName);
+                }
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                return false;
+            }
+
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
+            ).ToList();
+
+            _logger.LogInformation("Bulk importing {Count} rows into {TableName} using PostgreSQL COPY", filteredData.Count, tableName);
+
+            // Use PostgreSQL's COPY command for binary import (fastest method)
+            var quotedColumns = validColumns.Select(col => $"\"{col}\"");
+            var copyCommand = $"COPY \"{actualTableName}\" ({string.Join(", ", quotedColumns)}) FROM STDIN (FORMAT BINARY)";
+
+            using var writer = _connection.BeginBinaryImport(copyCommand);
+
+            foreach (var row in filteredData)
+            {
+                writer.StartRow();
+
+                var preparedRow = PrepareRowForInsert(row, validColumns);
+                for (int i = 0; i < preparedRow.Length; i++)
+                {
+                    var value = preparedRow[i];
+
+                    if (value == null || value == DBNull.Value)
+                    {
+                        writer.WriteNull();
+                    }
+                    else
+                    {
+                        // Write with appropriate type based on column type
+                        var colType = validColumnTypes[i];
+                        WriteValueForCopy(writer, value, colType);
+                    }
+                }
+            }
+
+            var rowsImported = writer.Complete();
+            _logger.LogInformation("Successfully bulk imported {RowsImported} rows into {TableName}", rowsImported, tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
+            _logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
+            if (ex.InnerException != null)
+            {
+                _logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
+            }
+            return false;
+        }
+    }
+
+    private void WriteValueForCopy(Npgsql.NpgsqlBinaryImporter writer, object value, string columnType)
+    {
+        // Handle type-specific writing for PostgreSQL COPY
+        switch (columnType.ToLower())
+        {
+            case "uuid":
+                if (value is string strGuid && Guid.TryParse(strGuid, out var guid))
+                    writer.Write(guid, NpgsqlDbType.Uuid);
+                else if (value is Guid g)
+                    writer.Write(g, NpgsqlDbType.Uuid);
+                else
+                    writer.Write(value.ToString()!, NpgsqlDbType.Uuid);
+                break;
+
+            case "boolean":
+                if (value is bool b)
+                    writer.Write(b);
+                else if (value is string strBool)
+                    writer.Write(strBool.Equals("true", StringComparison.OrdinalIgnoreCase) || strBool == "1");
+                else
+                    writer.Write(Convert.ToBoolean(value));
+                break;
+
+            case "smallint":
+                writer.Write(Convert.ToInt16(value));
+                break;
+
+            case "integer":
+                writer.Write(Convert.ToInt32(value));
+                break;
+
+            case "bigint":
+                writer.Write(Convert.ToInt64(value));
+                break;
+
+            case "real":
+                writer.Write(Convert.ToSingle(value));
+                break;
+
+            case "double precision":
+                writer.Write(Convert.ToDouble(value));
+                break;
+
+            case "numeric":
+            case "decimal":
+                writer.Write(Convert.ToDecimal(value));
+                break;
+
+            case "timestamp without time zone":
+            case "timestamp":
+                if (value is DateTime dt)
+                {
+                    // For timestamp without time zone, we can use the value as-is
+                    // But if it's Unspecified, treat it as if it's in the local context
+                    var timestampValue = dt.Kind == DateTimeKind.Unspecified
+                        ? DateTime.SpecifyKind(dt, DateTimeKind.Utc)
+                        : dt;
+                    writer.Write(timestampValue, NpgsqlDbType.Timestamp);
+                }
+                else if (value is string strDt && DateTime.TryParse(strDt, out var parsedDt))
+                {
+                    var timestampValue = DateTime.SpecifyKind(parsedDt, DateTimeKind.Utc);
+                    writer.Write(timestampValue, NpgsqlDbType.Timestamp);
+                }
+                else
+                    writer.Write(value.ToString()!);
+                break;
+
+            case "timestamp with time zone":
+            case "timestamptz":
+                if (value is DateTime dtz)
+                {
+                    // PostgreSQL timestamptz requires UTC DateTimes
+                    var utcValue = dtz.Kind == DateTimeKind.Unspecified
+                        ? DateTime.SpecifyKind(dtz, DateTimeKind.Utc)
+                        : dtz.Kind == DateTimeKind.Local
+                            ? dtz.ToUniversalTime()
+                            : dtz;
+                    writer.Write(utcValue, NpgsqlDbType.TimestampTz);
+                }
+                else if (value is string strDtz && DateTime.TryParse(strDtz, out var parsedDtz))
+                {
+                    // Parsed DateTimes are Unspecified, treat as UTC
+                    var utcValue = DateTime.SpecifyKind(parsedDtz, DateTimeKind.Utc);
+                    writer.Write(utcValue, NpgsqlDbType.TimestampTz);
+                }
+                else
+                    writer.Write(value.ToString()!);
+                break;
+
+            case "date":
+                if (value is DateTime date)
+                    writer.Write(date, NpgsqlDbType.Date);
+                else if (value is string strDate && DateTime.TryParse(strDate, out var parsedDate))
+                    writer.Write(parsedDate, NpgsqlDbType.Date);
+                else
+                    writer.Write(value.ToString()!);
+                break;
+
+            case "bytea":
+                if (value is byte[] bytes)
+                    writer.Write(bytes);
+                else
+                    writer.Write(value.ToString()!);
+                break;
+
+            default:
+                // Text and all other types
+                writer.Write(value.ToString()!);
+                break;
+        }
+    }
+
     public bool TestConnection()
     {
         try
diff --git a/util/Seeder/Migration/Databases/SqlServerImporter.cs b/util/Seeder/Migration/Databases/SqlServerImporter.cs
index 215d8784d9..6636e7ca64 100644
--- a/util/Seeder/Migration/Databases/SqlServerImporter.cs
+++ b/util/Seeder/Migration/Databases/SqlServerImporter.cs
@@ -81,6 +81,37 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
         }
     }
 
+    private Dictionary<string, string> GetTableColumnTypes(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT COLUMN_NAME, DATA_TYPE
+                FROM INFORMATION_SCHEMA.COLUMNS
+                WHERE TABLE_NAME = @TableName";
+
+            using var command = new SqlCommand(query, _connection);
+            command.Parameters.AddWithValue("@TableName", tableName);
+
+            var columnTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
+            using var reader = command.ExecuteReader();
+            while (reader.Read())
+            {
+                columnTypes[reader.GetString(0)] = reader.GetString(1);
+            }
+
+            return columnTypes;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting column types for table {TableName}: {Message}", tableName, ex.Message);
+            return new Dictionary<string, string>();
+        }
+    }
+
     public bool TableExists(string tableName)
     {
         if (_connection == null)
@@ -593,6 +624,18 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
         return preparedRow;
     }
 
+    private object[] PrepareRowForInsertWithTypes(object?[] row, List<string> columnTypes)
+    {
+        var preparedRow = new object[row.Length];
+
+        for (int i = 0; i < row.Length; i++)
+        {
+            preparedRow[i] = ConvertValueForSqlServerWithType(row[i], columnTypes[i]);
+        }
+
+        return preparedRow;
+    }
+
     private object ConvertValueForSqlServer(object? value)
     {
         if (value == null || value == DBNull.Value)
@@ -644,6 +687,197 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
         return value;
     }
 
+    private object ConvertValueForSqlServerWithType(object? value, string columnType)
+    {
+        if (value == null || value == DBNull.Value)
+            return DBNull.Value;
+
+        // Handle string conversions
+        if (value is string strValue)
+        {
+            // Only convert truly empty strings to DBNull, not whitespace
+            // This preserves JSON strings and other data that might have whitespace
+            if (strValue.Length == 0)
+                return DBNull.Value;
+
+            // Handle GUID values - SqlBulkCopy requires actual Guid objects for UNIQUEIDENTIFIER columns
+            // But NOT for NVARCHAR columns that happen to contain GUID strings
+            if (columnType.Equals("uniqueidentifier", StringComparison.OrdinalIgnoreCase))
+            {
+                if (Guid.TryParse(strValue, out var guidValue))
+                {
+                    return guidValue;
+                }
+            }
+
+            // Handle boolean-like values
+            if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase))
+                return 1;
+            if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase))
+                return 0;
+
+            // Handle datetime values - SQL Server DATETIME supports 3 decimal places
+            if (DateTimeHelper.IsLikelyIsoDateTime(strValue))
+            {
+                try
+                {
+                    // Remove timezone if present
+                    var datetimePart = strValue.Contains('+') || strValue.EndsWith('Z') || strValue.Contains('T')
+                        ? DateTimeHelper.RemoveTimezone(strValue) ?? strValue
+                        : strValue;
+
+                    // Handle microseconds - SQL Server DATETIME precision is 3.33ms, so truncate to 3 digits
+                    if (datetimePart.Contains('.'))
+                    {
+                        var parts = datetimePart.Split('.');
+                        if (parts.Length == 2 && parts[1].Length > 3)
+                        {
+                            datetimePart = $"{parts[0]}.{parts[1][..3]}";
+                        }
+                    }
+
+                    return datetimePart;
+                }
+                catch
+                {
+                    // If conversion fails, return original value
+                }
+            }
+        }
+
+        return value;
+    }
+
+    public bool SupportsBulkCopy()
+    {
+        return true; // SQL Server SqlBulkCopy is highly optimized
+    }
+
+    public bool ImportDataBulk(
+        string tableName,
+        List<string> columns,
+        List<object?[]> data)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            // Get actual table columns from SQL Server
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Filter columns and data
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+            var missingColumns = new List<string>();
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumns.Contains(columns[i]))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(columns[i]);
+                }
+                else
+                {
+                    missingColumns.Add(columns[i]);
+                }
+            }
+
+            if (missingColumns.Count > 0)
+            {
+                _logger.LogWarning("Skipping columns that don't exist in {TableName}: {Columns}", tableName, string.Join(", ", missingColumns));
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                return false;
+            }
+
+            // Get column types for proper data conversion
+            var columnTypes = GetTableColumnTypes(tableName);
+            var validColumnTypes = validColumns.Select(col =>
+                columnTypes.GetValueOrDefault(col, "nvarchar")).ToList();
+
+            // Filter data to only include valid columns
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
+            ).ToList();
+
+            // Check if table has identity columns
+            var identityColumns = GetIdentityColumns(tableName);
+            var identityColumnsInData = validColumns.Intersect(identityColumns).ToList();
+            var needsIdentityInsert = identityColumnsInData.Count > 0;
+
+            if (needsIdentityInsert)
+            {
+                _logger.LogInformation("Table {TableName} has identity columns in import data: {Columns}", tableName, string.Join(", ", identityColumnsInData));
+            }
+
+            _logger.LogInformation("Bulk importing {Count} rows into {TableName} using SqlBulkCopy", filteredData.Count, tableName);
+
+            // Use SqlBulkCopy for high-performance import
+            // When importing identity columns, we need SqlBulkCopyOptions.KeepIdentity
+            var bulkCopyOptions = needsIdentityInsert
+                ? SqlBulkCopyOptions.KeepIdentity
+                : SqlBulkCopyOptions.Default;
+
+            using var bulkCopy = new SqlBulkCopy(_connection, bulkCopyOptions, null)
+            {
+                DestinationTableName = $"[{tableName}]",
+                BatchSize = 10000,
+                BulkCopyTimeout = 600 // 10 minutes
+            };
+
+            // Map columns
+            foreach (var column in validColumns)
+            {
+                bulkCopy.ColumnMappings.Add(column, column);
+            }
+
+            // Create DataTable
+            var dataTable = new DataTable();
+            foreach (var column in validColumns)
+            {
+                dataTable.Columns.Add(column, typeof(object));
+            }
+
+            // Add rows with data type conversion based on actual column types
+            foreach (var row in filteredData)
+            {
+                var preparedRow = PrepareRowForInsertWithTypes(row, validColumnTypes);
+                dataTable.Rows.Add(preparedRow);
+            }
+
+            bulkCopy.WriteToServer(dataTable);
+            _logger.LogInformation("Successfully bulk imported {Count} rows into {TableName}", filteredData.Count, tableName);
+
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
+            _logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
+            if (ex.InnerException != null)
+            {
+                _logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
+            }
+            return false;
+        }
+    }
+
     public bool TestConnection()
     {
         try
diff --git a/util/Seeder/Migration/Databases/SqliteImporter.cs b/util/Seeder/Migration/Databases/SqliteImporter.cs
index fd4e13c845..7900f93431 100644
--- a/util/Seeder/Migration/Databases/SqliteImporter.cs
+++ b/util/Seeder/Migration/Databases/SqliteImporter.cs
@@ -415,6 +415,23 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
         }).ToArray();
     }
 
+    public bool SupportsBulkCopy()
+    {
+        // SQLite performs better with the original row-by-row INSERT approach
+        // Multi-row INSERT causes performance degradation for SQLite
+        return false;
+    }
+
+    public bool ImportDataBulk(
+        string tableName,
+        List<string> columns,
+        List<object?[]> data)
+    {
+        // Not implemented for SQLite - use standard ImportData instead
+        _logger.LogWarning("Bulk copy not supported for SQLite, use standard import");
+        return false;
+    }
+
     public bool TestConnection()
     {
         try
diff --git a/util/Seeder/Recipes/CsvMigrationRecipe.cs b/util/Seeder/Recipes/CsvMigrationRecipe.cs
index 7cbfb0e646..ea66af5f74 100644
--- a/util/Seeder/Recipes/CsvMigrationRecipe.cs
+++ b/util/Seeder/Recipes/CsvMigrationRecipe.cs
@@ -302,8 +302,26 @@ public class CsvMigrationRecipe(MigrationConfig config, ILoggerFactory loggerFac
             }
         }
 
-        var effectiveBatchSize = batchSize ?? _config.BatchSize;
-        var success = importer.ImportData(destTableName, columns, data, effectiveBatchSize);
+        // Try bulk copy first for better performance, fall back to row-by-row if needed
+        bool success;
+        if (importer.SupportsBulkCopy())
+        {
+            _logger.LogInformation("Using optimized bulk copy for {TableName}", tableName);
+            success = importer.ImportDataBulk(destTableName, columns, data);
+
+            if (!success)
+            {
+                _logger.LogWarning("Bulk copy failed for {TableName}, falling back to standard import", tableName);
+                var effectiveBatchSize = batchSize ?? _config.BatchSize;
+                success = importer.ImportData(destTableName, columns, data, effectiveBatchSize);
+            }
+        }
+        else
+        {
+            _logger.LogInformation("Using standard import for {TableName}", tableName);
+            var effectiveBatchSize = batchSize ?? _config.BatchSize;
+            success = importer.ImportData(destTableName, columns, data, effectiveBatchSize);
+        }
 
         if (success)
         {
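
Usage sketch (illustrative, not part of the diff): the CsvMigrationRecipe hunk above encodes the caller contract for the new interface members - call ImportDataBulk() when SupportsBulkCopy() returns true, and fall back to the batched ImportData() path when bulk copy is unsupported or reports failure by returning false. The helper below restates that pattern against IDatabaseImporter; the extension class, method name, and the Bit.Seeder.Migration.Databases using are assumptions inferred from the file paths, not code introduced by this patch.

using System.Collections.Generic;
using Microsoft.Extensions.Logging;
using Bit.Seeder.Migration.Databases; // assumed namespace, inferred from util/Seeder/Migration/Databases

public static class ImporterBulkCopyExtensions
{
    // Hypothetical helper: try the optimized bulk path first, then fall back to
    // the row-by-row ImportData() path, mirroring the CsvMigrationRecipe change.
    public static bool ImportWithBulkFallback(
        this IDatabaseImporter importer,
        ILogger logger,
        string tableName,
        List<string> columns,
        List<object?[]> data,
        int batchSize)
    {
        if (importer.SupportsBulkCopy())
        {
            logger.LogInformation("Using optimized bulk copy for {TableName}", tableName);
            if (importer.ImportDataBulk(tableName, columns, data))
            {
                return true;
            }

            // Per the interface remarks, a false return means the caller should
            // retry with the standard row-by-row import.
            logger.LogWarning("Bulk copy failed for {TableName}, falling back to standard import", tableName);
        }

        return importer.ImportData(tableName, columns, data, batchSize);
    }
}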