mirror of https://github.com/bitwarden/server synced 2025-12-24 20:23:21 +00:00

Added bulk copy

This commit is contained in:
Mark Kincaid
2025-10-29 15:40:49 -07:00
parent 71d5d3dd17
commit 8cf7327ca6
6 changed files with 671 additions and 2 deletions

View File

@@ -84,6 +84,30 @@ public interface IDatabaseImporter : IDisposable
/// <returns>True if foreign keys were enabled successfully, false otherwise.</returns>
bool EnableForeignKeys();
/// <summary>
/// Checks if this importer supports optimized bulk copy operations.
/// </summary>
/// <returns>True if bulk copy is supported and should be preferred over row-by-row import.</returns>
bool SupportsBulkCopy();
/// <summary>
/// Imports data into a table using database-specific bulk copy operations for optimal performance.
/// This method uses native bulk import mechanisms like PostgreSQL COPY, SQL Server SqlBulkCopy,
/// or multi-row INSERT statements for databases that support them.
/// </summary>
/// <param name="tableName">Name of the target table.</param>
/// <param name="columns">List of column names in the data.</param>
/// <param name="data">Data rows to import.</param>
/// <returns>True if bulk import was successful, false otherwise.</returns>
/// <remarks>
/// This method is significantly faster than ImportData() for large datasets (10-100x speedup).
/// If this method returns false, the caller should fall back to ImportData().
/// </remarks>
bool ImportDataBulk(
string tableName,
List<string> columns,
List<object[]> data);
/// <summary>
/// Tests the connection to the database.
/// </summary>
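The contract these members establish for callers (exercised by the CsvMigrationRecipe change at the end of this commit): probe SupportsBulkCopy() first, and if ImportDataBulk() returns false for any reason, retry with the existing row-by-row ImportData() path. A minimal caller-side sketch; the TryImport wrapper is illustrative only, and the ImportData(table, columns, data, batchSize) signature is assumed from its use elsewhere in this commit:

public static bool TryImport(IDatabaseImporter importer, string table, List<string> columns, List<object[]> rows, int batchSize)
{
    // Prefer the optimized bulk path when the importer advertises support for it.
    if (importer.SupportsBulkCopy() && importer.ImportDataBulk(table, columns, rows))
        return true;

    // Fall back to the batched row-by-row import when bulk copy is unavailable or reports failure.
    return importer.ImportData(table, columns, rows, batchSize);
}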

View File

@@ -412,6 +412,151 @@ public class MariaDbImporter(DatabaseConfig config, ILogger<MariaDbImporter> log
}).ToArray();
}
public bool SupportsBulkCopy()
{
return true; // MariaDB multi-row INSERT is optimized
}
public bool ImportDataBulk(
string tableName,
List<string> columns,
List<object[]> data)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
return true;
}
try
{
var actualColumns = GetTableColumns(tableName);
if (actualColumns.Count == 0)
{
_logger.LogError("Could not retrieve columns for table {TableName}", tableName);
return false;
}
// Filter columns
var validColumnIndices = new List<int>();
var validColumns = new List<string>();
for (int i = 0; i < columns.Count; i++)
{
if (actualColumns.Contains(columns[i]))
{
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
}
if (validColumns.Count == 0)
{
_logger.LogError("No valid columns found for table {TableName}", tableName);
return false;
}
var filteredData = data.Select(row =>
validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
).ToList();
_logger.LogInformation("Bulk importing {Count} rows into {TableName} using multi-row INSERT", filteredData.Count, tableName);
// Use multi-row INSERT for better performance
// INSERT INTO table (col1, col2) VALUES (val1, val2), (val3, val4), ...
// MariaDB can handle up to max_allowed_packet size, we'll use 1000 rows per batch
const int rowsPerBatch = 1000;
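// Note: parameters per statement scale as rowsPerBatch * column count, so very wide tables
// may need a smaller batch size to stay within server/driver limits (assumption, not tuned here).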
var totalImported = 0;
for (int i = 0; i < filteredData.Count; i += rowsPerBatch)
{
var batch = filteredData.Skip(i).Take(rowsPerBatch).ToList();
using var transaction = _connection.BeginTransaction();
try
{
// Build multi-row INSERT statement
var quotedColumns = validColumns.Select(col => $"`{col}`").ToList();
var columnPart = $"INSERT INTO `{tableName}` ({string.Join(", ", quotedColumns)}) VALUES ";
var valueSets = new List<string>();
var allParameters = new List<(string name, object value)>();
var paramIndex = 0;
foreach (var row in batch)
{
var preparedRow = PrepareRowForInsert(row, validColumns);
var rowParams = new List<string>();
for (int p = 0; p < preparedRow.Length; p++)
{
var paramName = $"@p{paramIndex}";
rowParams.Add(paramName);
allParameters.Add((paramName, preparedRow[p] ?? DBNull.Value));
paramIndex++;
}
valueSets.Add($"({string.Join(", ", rowParams)})");
}
var fullInsertSql = columnPart + string.Join(", ", valueSets);
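// e.g. for a hypothetical two-row, two-column batch this produces:
// INSERT INTO `User` (`Id`, `Email`) VALUES (@p0, @p1), (@p2, @p3)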
using var command = new MySqlCommand(fullInsertSql, _connection, transaction);
// Add all parameters
foreach (var (name, value) in allParameters)
{
if (value is string strValue)
{
var param = new MySqlConnector.MySqlParameter
{
ParameterName = name,
MySqlDbType = MySqlConnector.MySqlDbType.LongText,
Value = strValue,
Size = strValue.Length
};
command.Parameters.Add(param);
}
else
{
command.Parameters.AddWithValue(name, value);
}
}
command.ExecuteNonQuery();
transaction.Commit();
totalImported += batch.Count;
if (filteredData.Count > rowsPerBatch)
{
_logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
}
}
catch
{
transaction.Rollback();
throw;
}
}
_logger.LogInformation("Successfully bulk imported {TotalImported} rows into {TableName}", totalImported, tableName);
return true;
}
catch (Exception ex)
{
_logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
_logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
if (ex.InnerException != null)
{
_logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
}
return false;
}
}
public bool TestConnection()
{
try

View File

@@ -1,4 +1,5 @@
using Npgsql;
using NpgsqlTypes;
using Bit.Seeder.Migration.Models;
using Microsoft.Extensions.Logging;
@@ -525,6 +526,236 @@ public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> l
}).ToArray();
}
public bool SupportsBulkCopy()
{
return true; // PostgreSQL COPY is highly optimized
}
public bool ImportDataBulk(
string tableName,
List<string> columns,
List<object[]> data)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
return true;
}
try
{
// Get the actual table name with correct casing
var actualTableName = GetActualTableName(tableName);
if (actualTableName == null)
{
_logger.LogError("Table {TableName} not found in database", tableName);
return false;
}
var actualColumns = GetTableColumns(tableName);
if (actualColumns.Count == 0)
{
_logger.LogError("Could not retrieve columns for table {TableName}", tableName);
return false;
}
// Get column types from the database
var columnTypes = GetTableColumnTypes(tableName);
// Filter columns - use case-insensitive comparison
var validColumnIndices = new List<int>();
var validColumns = new List<string>();
var validColumnTypes = new List<string>();
// Create a case-insensitive lookup of actual columns
var actualColumnsLookup = actualColumns.ToDictionary(c => c, c => c, StringComparer.OrdinalIgnoreCase);
for (int i = 0; i < columns.Count; i++)
{
if (actualColumnsLookup.TryGetValue(columns[i], out var actualColumnName))
{
validColumnIndices.Add(i);
validColumns.Add(actualColumnName);
validColumnTypes.Add(columnTypes.GetValueOrDefault(actualColumnName, "text"));
}
else
{
_logger.LogDebug("Column '{Column}' from CSV not found in table {TableName}", columns[i], tableName);
}
}
if (validColumns.Count == 0)
{
_logger.LogError("No valid columns found for table {TableName}", tableName);
return false;
}
var filteredData = data.Select(row =>
validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
).ToList();
_logger.LogInformation("Bulk importing {Count} rows into {TableName} using PostgreSQL COPY", filteredData.Count, tableName);
// Use PostgreSQL's COPY command for binary import (fastest method)
var quotedColumns = validColumns.Select(col => $"\"{col}\"");
var copyCommand = $"COPY \"{actualTableName}\" ({string.Join(", ", quotedColumns)}) FROM STDIN (FORMAT BINARY)";
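// e.g. for a hypothetical "User" table this produces:
// COPY "User" ("Id", "Email") FROM STDIN (FORMAT BINARY)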
using var writer = _connection.BeginBinaryImport(copyCommand);
foreach (var row in filteredData)
{
writer.StartRow();
var preparedRow = PrepareRowForInsert(row, validColumns);
for (int i = 0; i < preparedRow.Length; i++)
{
var value = preparedRow[i];
if (value == null || value == DBNull.Value)
{
writer.WriteNull();
}
else
{
// Write with appropriate type based on column type
var colType = validColumnTypes[i];
WriteValueForCopy(writer, value, colType);
}
}
}
var rowsImported = writer.Complete();
_logger.LogInformation("Successfully bulk imported {RowsImported} rows into {TableName}", rowsImported, tableName);
return true;
}
catch (Exception ex)
{
_logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
_logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
if (ex.InnerException != null)
{
_logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
}
return false;
}
}
private void WriteValueForCopy(Npgsql.NpgsqlBinaryImporter writer, object value, string columnType)
{
// Handle type-specific writing for PostgreSQL COPY
switch (columnType.ToLower())
{
case "uuid":
if (value is string strGuid && Guid.TryParse(strGuid, out var guid))
writer.Write(guid, NpgsqlDbType.Uuid);
else if (value is Guid g)
writer.Write(g, NpgsqlDbType.Uuid);
else
writer.Write(value.ToString()!, NpgsqlDbType.Uuid);
break;
case "boolean":
if (value is bool b)
writer.Write(b);
else if (value is string strBool)
writer.Write(strBool.Equals("true", StringComparison.OrdinalIgnoreCase) || strBool == "1");
else
writer.Write(Convert.ToBoolean(value));
break;
case "smallint":
writer.Write(Convert.ToInt16(value));
break;
case "integer":
writer.Write(Convert.ToInt32(value));
break;
case "bigint":
writer.Write(Convert.ToInt64(value));
break;
case "real":
writer.Write(Convert.ToSingle(value));
break;
case "double precision":
writer.Write(Convert.ToDouble(value));
break;
case "numeric":
case "decimal":
writer.Write(Convert.ToDecimal(value));
break;
case "timestamp without time zone":
case "timestamp":
if (value is DateTime dt)
{
// For timestamp without time zone, we can use the value as-is
// But if it's Unspecified, treat it as if it's in the local context
var timestampValue = dt.Kind == DateTimeKind.Unspecified
? DateTime.SpecifyKind(dt, DateTimeKind.Utc)
: dt;
writer.Write(timestampValue, NpgsqlDbType.Timestamp);
}
else if (value is string strDt && DateTime.TryParse(strDt, out var parsedDt))
{
var timestampValue = DateTime.SpecifyKind(parsedDt, DateTimeKind.Utc);
writer.Write(timestampValue, NpgsqlDbType.Timestamp);
}
else
writer.Write(value.ToString()!);
break;
case "timestamp with time zone":
case "timestamptz":
if (value is DateTime dtz)
{
// PostgreSQL timestamptz requires UTC DateTimes
var utcValue = dtz.Kind == DateTimeKind.Unspecified
? DateTime.SpecifyKind(dtz, DateTimeKind.Utc)
: dtz.Kind == DateTimeKind.Local
? dtz.ToUniversalTime()
: dtz;
writer.Write(utcValue, NpgsqlDbType.TimestampTz);
}
else if (value is string strDtz && DateTime.TryParse(strDtz, out var parsedDtz))
{
// Parsed DateTimes are Unspecified, treat as UTC
var utcValue = DateTime.SpecifyKind(parsedDtz, DateTimeKind.Utc);
writer.Write(utcValue, NpgsqlDbType.TimestampTz);
}
else
writer.Write(value.ToString()!);
break;
case "date":
if (value is DateTime date)
writer.Write(date, NpgsqlDbType.Date);
else if (value is string strDate && DateTime.TryParse(strDate, out var parsedDate))
writer.Write(parsedDate, NpgsqlDbType.Date);
else
writer.Write(value.ToString()!);
break;
case "bytea":
if (value is byte[] bytes)
writer.Write(bytes);
else
writer.Write(value.ToString()!);
break;
default:
// Text and all other types
writer.Write(value.ToString()!);
break;
}
}
public bool TestConnection()
{
try

View File

@@ -81,6 +81,37 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
private Dictionary<string, string> GetTableColumnTypes(string tableName)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
try
{
var query = @"
SELECT COLUMN_NAME, DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME = @TableName";
using var command = new SqlCommand(query, _connection);
command.Parameters.AddWithValue("@TableName", tableName);
var columnTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
using var reader = command.ExecuteReader();
while (reader.Read())
{
columnTypes[reader.GetString(0)] = reader.GetString(1);
}
return columnTypes;
}
catch (Exception ex)
{
_logger.LogError("Error getting column types for table {TableName}: {Message}", tableName, ex.Message);
return new Dictionary<string, string>();
}
}
public bool TableExists(string tableName)
{
if (_connection == null)
@@ -593,6 +624,18 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
return preparedRow;
}
private object[] PrepareRowForInsertWithTypes(object?[] row, List<string> columnTypes)
{
var preparedRow = new object[row.Length];
for (int i = 0; i < row.Length; i++)
{
preparedRow[i] = ConvertValueForSqlServerWithType(row[i], columnTypes[i]);
}
return preparedRow;
}
private object ConvertValueForSqlServer(object? value)
{
if (value == null || value == DBNull.Value)
@@ -644,6 +687,197 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
return value;
}
private object ConvertValueForSqlServerWithType(object? value, string columnType)
{
if (value == null || value == DBNull.Value)
return DBNull.Value;
// Handle string conversions
if (value is string strValue)
{
// Only convert truly empty strings to DBNull, not whitespace
// This preserves JSON strings and other data that might have whitespace
if (strValue.Length == 0)
return DBNull.Value;
// Handle GUID values - SqlBulkCopy requires actual Guid objects for UNIQUEIDENTIFIER columns
// But NOT for NVARCHAR columns that happen to contain GUID strings
if (columnType.Equals("uniqueidentifier", StringComparison.OrdinalIgnoreCase))
{
if (Guid.TryParse(strValue, out var guidValue))
{
return guidValue;
}
}
// Handle boolean-like values
if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase))
return 1;
if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase))
return 0;
// Handle datetime values - SQL Server DATETIME supports 3 decimal places
if (DateTimeHelper.IsLikelyIsoDateTime(strValue))
{
try
{
// Remove timezone if present
var datetimePart = strValue.Contains('+') || strValue.EndsWith('Z') || strValue.Contains('T')
? DateTimeHelper.RemoveTimezone(strValue) ?? strValue
: strValue;
// Handle microseconds - SQL Server DATETIME precision is 3.33ms, so truncate to 3 digits
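// e.g. "2024-01-02T03:04:05.1234567" becomes "2024-01-02T03:04:05.123"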
if (datetimePart.Contains('.'))
{
var parts = datetimePart.Split('.');
if (parts.Length == 2 && parts[1].Length > 3)
{
datetimePart = $"{parts[0]}.{parts[1][..3]}";
}
}
return datetimePart;
}
catch
{
// If conversion fails, return original value
}
}
}
return value;
}
public bool SupportsBulkCopy()
{
return true; // SQL Server SqlBulkCopy is highly optimized
}
public bool ImportDataBulk(
string tableName,
List<string> columns,
List<object[]> data)
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
if (data.Count == 0)
{
_logger.LogWarning("No data to import for table {TableName}", tableName);
return true;
}
try
{
// Get actual table columns from SQL Server
var actualColumns = GetTableColumns(tableName);
if (actualColumns.Count == 0)
{
_logger.LogError("Could not retrieve columns for table {TableName}", tableName);
return false;
}
// Filter columns and data
var validColumnIndices = new List<int>();
var validColumns = new List<string>();
var missingColumns = new List<string>();
for (int i = 0; i < columns.Count; i++)
{
if (actualColumns.Contains(columns[i]))
{
validColumnIndices.Add(i);
validColumns.Add(columns[i]);
}
else
{
missingColumns.Add(columns[i]);
}
}
if (missingColumns.Count > 0)
{
_logger.LogWarning("Skipping columns that don't exist in {TableName}: {Columns}", tableName, string.Join(", ", missingColumns));
}
if (validColumns.Count == 0)
{
_logger.LogError("No valid columns found for table {TableName}", tableName);
return false;
}
// Get column types for proper data conversion
var columnTypes = GetTableColumnTypes(tableName);
var validColumnTypes = validColumns.Select(col =>
columnTypes.GetValueOrDefault(col, "nvarchar")).ToList();
// Filter data to only include valid columns
var filteredData = data.Select(row =>
validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
).ToList();
// Check if table has identity columns
var identityColumns = GetIdentityColumns(tableName);
var identityColumnsInData = validColumns.Intersect(identityColumns).ToList();
var needsIdentityInsert = identityColumnsInData.Count > 0;
if (needsIdentityInsert)
{
_logger.LogInformation("Table {TableName} has identity columns in import data: {Columns}", tableName, string.Join(", ", identityColumnsInData));
}
_logger.LogInformation("Bulk importing {Count} rows into {TableName} using SqlBulkCopy", filteredData.Count, tableName);
// Use SqlBulkCopy for high-performance import
// When importing identity columns, we need SqlBulkCopyOptions.KeepIdentity
var bulkCopyOptions = needsIdentityInsert
? SqlBulkCopyOptions.KeepIdentity
: SqlBulkCopyOptions.Default;
using var bulkCopy = new SqlBulkCopy(_connection, bulkCopyOptions, null)
{
DestinationTableName = $"[{tableName}]",
BatchSize = 10000,
BulkCopyTimeout = 600 // 10 minutes
};
// Map columns
foreach (var column in validColumns)
{
bulkCopy.ColumnMappings.Add(column, column);
}
// Create DataTable
var dataTable = new DataTable();
foreach (var column in validColumns)
{
dataTable.Columns.Add(column, typeof(object));
}
// Add rows with data type conversion based on actual column types
foreach (var row in filteredData)
{
var preparedRow = PrepareRowForInsertWithTypes(row, validColumnTypes);
dataTable.Rows.Add(preparedRow);
}
bulkCopy.WriteToServer(dataTable);
_logger.LogInformation("Successfully bulk imported {Count} rows into {TableName}", filteredData.Count, tableName);
return true;
}
catch (Exception ex)
{
_logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
_logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
if (ex.InnerException != null)
{
_logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
}
return false;
}
}
public bool TestConnection()
{
try

View File

@@ -415,6 +415,23 @@ public class SqliteImporter(DatabaseConfig config, ILogger<SqliteImporter> logge
}).ToArray();
}
public bool SupportsBulkCopy()
{
// SQLite performs better with the original row-by-row INSERT approach
// Multi-row INSERT causes performance degradation for SQLite
return false;
}
public bool ImportDataBulk(
string tableName,
List<string> columns,
List<object[]> data)
{
// Not implemented for SQLite - use standard ImportData instead
_logger.LogWarning("Bulk copy not supported for SQLite, use standard import");
return false;
}
public bool TestConnection()
{
try

View File

@@ -302,8 +302,26 @@ public class CsvMigrationRecipe(MigrationConfig config, ILoggerFactory loggerFac
}
}
// Try bulk copy first for better performance, fall back to row-by-row if needed
// (replaces the previous unconditional importer.ImportData call)
bool success;
if (importer.SupportsBulkCopy())
{
_logger.LogInformation("Using optimized bulk copy for {TableName}", tableName);
success = importer.ImportDataBulk(destTableName, columns, data);
if (!success)
{
_logger.LogWarning("Bulk copy failed for {TableName}, falling back to standard import", tableName);
var effectiveBatchSize = batchSize ?? _config.BatchSize;
success = importer.ImportData(destTableName, columns, data, effectiveBatchSize);
}
}
else
{
_logger.LogInformation("Using standard import for {TableName}", tableName);
var effectiveBatchSize = batchSize ?? _config.BatchSize;
success = importer.ImportData(destTableName, columns, data, effectiveBatchSize);
}
if (success)
{