mirror of
https://github.com/bitwarden/server
synced 2025-12-25 12:43:14 +00:00
672 lines
23 KiB
C#
672 lines
23 KiB
C#
using Microsoft.Data.SqlClient;
|
|
using Bit.Seeder.Migration.Models;
|
|
using Bit.Seeder.Migration.Utils;
|
|
using Microsoft.Extensions.Logging;
|
|
using System.Data;
|
|
|
|
namespace Bit.Seeder.Migration.Databases;
|
|
|
|
/// <summary>
/// <see cref="IDatabaseImporter"/> for SQL Server. It creates tables, toggles foreign-key
/// constraints and IDENTITY_INSERT, and inserts rows. All work runs over a single connection
/// opened by <see cref="Connect"/>.
/// </summary>
/// <remarks>
/// Object names (tables, columns, schemas, constraints) are interpolated into SQL through
/// <see cref="QuoteIdentifier"/>. Row values are always sent as command parameters.
/// NOTE(review): the INFORMATION_SCHEMA lookups match on table name only, while DDL/DML uses
/// unqualified names resolved through the login's default schema. Tables with the same name in
/// several schemas would be mixed together. Confirm the target database uses a single schema.
/// </remarks>
public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter> logger) : IDatabaseImporter
{
    private readonly ILogger<SqlServerImporter> _logger = logger;
    private readonly string _host = config.Host;
    private readonly int _port = config.Port;
    private readonly string _database = config.Database;
    private readonly string _username = config.Username;
    private readonly string _password = config.Password;
    private SqlConnection? _connection;

    // Foreign keys this importer switched to NOCHECK and has not yet successfully re-enabled.
    private readonly List<(string Schema, string Table, string Constraint)> _disabledConstraints = [];

    /// <summary>
    /// Opens a new connection to the configured server and database. Any connection this
    /// instance already holds is disposed first.
    /// </summary>
    /// <returns><c>true</c> when the connection was opened; <c>false</c> if it failed (the error is logged).</returns>
    public bool Connect()
    {
        // A repeated Connect must not silently leak the previous connection.
        _connection?.Dispose();
        _connection = null;

        SqlConnection? connection = null;
        try
        {
            // A builder quotes values containing ';', '=' or quotes (typically passwords)
            // instead of letting them be misparsed as extra connection-string keywords.
            var builder = new System.Data.Common.DbConnectionStringBuilder
            {
                ["Server"] = $"{_host},{_port}",
                ["Database"] = _database,
                ["User Id"] = _username,
                ["Password"] = _password,
                ["TrustServerCertificate"] = "True",
                ["Connection Timeout"] = "30",
            };

            connection = new SqlConnection(builder.ConnectionString);
            connection.Open();

            // Publish the connection only once it is open. A failed Open therefore never leaves
            // a closed connection behind that the "Not connected" guards would accept.
            _connection = connection;

            _logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
            return true;
        }
        catch (Exception ex)
        {
            connection?.Dispose();
            _logger.LogError(ex, "Failed to connect to SQL Server: {Message}", ex.Message);
            return false;
        }
    }

    /// <summary>Closes and disposes the current connection, if any. Safe to call repeatedly.</summary>
    public void Disconnect()
    {
        if (_connection is null)
        {
            return;
        }

        _connection.Close();
        _connection.Dispose();
        _connection = null;
        _logger.LogInformation("Disconnected from SQL Server");
    }

    /// <summary>Returns the current connection, or throws when <see cref="Connect"/> has not succeeded.</summary>
    /// <exception cref="InvalidOperationException">No connection is currently held.</exception>
    private SqlConnection RequireConnection() =>
        _connection ?? throw new InvalidOperationException("Not connected to database");

    /// <summary>
    /// Wraps a SQL Server identifier in square brackets and doubles any embedded <c>]</c>
    /// (the escaping T-SQL's QUOTENAME uses), so a name cannot terminate the quoting early.
    /// </summary>
    private static string QuoteIdentifier(string name) =>
        $"[{name.Replace("]", "]]", StringComparison.Ordinal)}]";

    /// <summary>Lists the column names of <paramref name="tableName"/> in ordinal order.</summary>
    /// <returns>The column names, or an empty list if the lookup failed (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public List<string> GetTableColumns(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            const string query = @"
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_NAME = @TableName
            ORDER BY ORDINAL_POSITION";

            using var command = new SqlCommand(query, connection);
            command.Parameters.AddWithValue("@TableName", tableName);

            var columns = new List<string>();
            using var reader = command.ExecuteReader();
            while (reader.Read())
            {
                columns.Add(reader.GetString(0));
            }

            return columns;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
            return [];
        }
    }

    /// <summary>Reports whether a base table (not a view) named <paramref name="tableName"/> exists.</summary>
    /// <returns><c>true</c> if found; <c>false</c> if not found or the lookup failed (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool TableExists(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            const string query = @"
            SELECT COUNT(*)
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_NAME = @TableName AND TABLE_TYPE = 'BASE TABLE'";

            using var command = new SqlCommand(query, connection);
            command.Parameters.AddWithValue("@TableName", tableName);

            var count = (int)command.ExecuteScalar()!;
            return count > 0;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error checking if table {TableName} exists: {Message}", tableName, ex.Message);
            return false;
        }
    }

    /// <summary>Counts the rows currently in <paramref name="tableName"/>.</summary>
    /// <returns>The row count. Returns 0 if the query failed (the error is logged), which is indistinguishable from an empty table.</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public int GetTableRowCount(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            var query = $"SELECT COUNT(*) FROM {QuoteIdentifier(tableName)}";
            using var command = new SqlCommand(query, connection);

            var count = (int)command.ExecuteScalar()!;
            _logger.LogDebug("Row count for {TableName}: {Count}", tableName, count);
            return count;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error getting row count for {TableName}: {Message}", tableName, ex.Message);
            return 0;
        }
    }

    /// <summary>Drops <paramref name="tableName"/> if it exists.</summary>
    /// <returns><c>true</c> on success; <c>false</c> if the statement failed (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool DropTable(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            var query = $"DROP TABLE IF EXISTS {QuoteIdentifier(tableName)}";
            using var command = new SqlCommand(query, connection);
            command.ExecuteNonQuery();

            _logger.LogInformation("Dropped table {TableName}", tableName);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error dropping table {TableName}: {Message}", tableName, ex.Message);
            return false;
        }
    }

    /// <summary>
    /// Switches every currently enabled foreign key in the database to NOCHECK and remembers
    /// each one so <see cref="EnableForeignKeys"/> can restore it.
    /// </summary>
    /// <returns>
    /// <c>true</c> unless listing the constraints failed. Individual constraints that cannot be
    /// disabled are logged as warnings and skipped.
    /// </returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool DisableForeignKeys()
    {
        var connection = RequireConnection();

        try
        {
            _logger.LogInformation("Disabling foreign key constraints for SQL Server");

            const string query = @"
            SELECT
                OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
                OBJECT_NAME(parent_object_id) AS table_name,
                name AS constraint_name
            FROM sys.foreign_keys
            WHERE is_disabled = 0";

            // Read the full list first. The reader must be closed before the ALTER TABLE commands
            // below run on the same connection (MARS is not enabled in the connection string).
            var constraints = new List<(string Schema, string Table, string Constraint)>();
            using (var command = new SqlCommand(query, connection))
            using (var reader = command.ExecuteReader())
            {
                while (reader.Read())
                {
                    constraints.Add((reader.GetString(0), reader.GetString(1), reader.GetString(2)));
                }
            }

            var newlyDisabled = 0;
            foreach (var entry in constraints)
            {
                var (schema, table, constraint) = entry;
                try
                {
                    var disableSql =
                        $"ALTER TABLE {QuoteIdentifier(schema)}.{QuoteIdentifier(table)} NOCHECK CONSTRAINT {QuoteIdentifier(constraint)}";
                    using var disableCommand = new SqlCommand(disableSql, connection);
                    disableCommand.ExecuteNonQuery();

                    // Accumulate rather than reset. Constraints disabled by an earlier call no longer
                    // match "is_disabled = 0" and would otherwise never be re-enabled.
                    if (!_disabledConstraints.Contains(entry))
                    {
                        _disabledConstraints.Add(entry);
                    }

                    newlyDisabled++;
                    _logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
                }
                catch (Exception ex)
                {
                    _logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message);
                }
            }

            _logger.LogInformation("Disabled {Count} foreign key constraints", newlyDisabled);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error disabling foreign key constraints: {Message}", ex.Message);
            return false;
        }
    }

    /// <summary>
    /// Re-enables the foreign keys previously disabled by <see cref="DisableForeignKeys"/>.
    /// Constraints that fail to re-enable are kept in the tracking list so a later call can retry them.
    /// </summary>
    /// <returns>
    /// <c>true</c> unless an unexpected error aborted the pass. Per-constraint failures are logged
    /// as warnings.
    /// </returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool EnableForeignKeys()
    {
        var connection = RequireConnection();

        try
        {
            _logger.LogInformation("Re-enabling foreign key constraints for SQL Server");

            var totalCount = _disabledConstraints.Count;
            var enabledCount = 0;
            var stillDisabled = new List<(string Schema, string Table, string Constraint)>();

            foreach (var entry in _disabledConstraints)
            {
                var (schema, table, constraint) = entry;
                try
                {
                    // NOTE(review): there is no WITH CHECK here, so SQL Server presumably does not
                    // re-validate rows inserted while the constraint was off. The constraint stays
                    // "not trusted". Confirm this is intended.
                    var enableSql =
                        $"ALTER TABLE {QuoteIdentifier(schema)}.{QuoteIdentifier(table)} CHECK CONSTRAINT {QuoteIdentifier(constraint)}";
                    using var command = new SqlCommand(enableSql, connection);
                    command.ExecuteNonQuery();

                    enabledCount++;
                    _logger.LogDebug("Re-enabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
                }
                catch (Exception ex)
                {
                    stillDisabled.Add(entry);
                    _logger.LogWarning("Could not re-enable constraint {Constraint}: {Message}", constraint, ex.Message);
                }
            }

            _logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, totalCount);

            // Only forget constraints that were actually restored.
            _disabledConstraints.Clear();
            _disabledConstraints.AddRange(stillDisabled);

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error re-enabling foreign key constraints: {Message}", ex.Message);
            return false;
        }
    }

    /// <summary>
    /// Creates <paramref name="tableName"/> with one column per entry in <paramref name="columns"/>.
    /// </summary>
    /// <param name="tableName">Name of the table to create.</param>
    /// <param name="columns">Column names, in table order.</param>
    /// <param name="columnTypes">
    /// SQL type per column name. Missing entries default to <c>NVARCHAR(MAX)</c>. Type strings are
    /// inserted verbatim into the DDL, so they must come from a trusted schema source.
    /// </param>
    /// <param name="specialColumns">
    /// Columns that must hold large text (e.g. JSON). Any of these not already typed as
    /// <c>...VARCHAR(MAX)</c> or <c>...TEXT</c> is forced to <c>NVARCHAR(MAX)</c>.
    /// </param>
    /// <returns><c>true</c> if the table was created; otherwise <c>false</c> (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool CreateTableFromSchema(
        string tableName,
        List<string> columns,
        Dictionary<string, string> columnTypes,
        List<string>? specialColumns = null)
    {
        specialColumns ??= [];
        var connection = RequireConnection();

        if (columns.Count == 0)
        {
            _logger.LogError("Cannot create table {TableName}: no columns were supplied", tableName);
            return false;
        }

        try
        {
            var columnDefinitions = new List<string>(columns.Count);
            foreach (var colName in columns)
            {
                var colType = columnTypes.GetValueOrDefault(colName, "NVARCHAR(MAX)");

                // Culture-independent, case-insensitive type check.
                var isLargeTextType =
                    colType.Contains("VARCHAR(MAX)", StringComparison.OrdinalIgnoreCase) ||
                    colType.Contains("TEXT", StringComparison.OrdinalIgnoreCase);
                if (specialColumns.Contains(colName) && !isLargeTextType)
                {
                    colType = "NVARCHAR(MAX)";
                }

                columnDefinitions.Add($"{QuoteIdentifier(colName)} {colType}");
            }

            var createSql = $@"
            CREATE TABLE {QuoteIdentifier(tableName)} (
                {string.Join(",\n                ", columnDefinitions)}
            )";

            _logger.LogInformation("Creating table {TableName} in SQL Server", tableName);
            _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);

            using var command = new SqlCommand(createSql, connection);
            command.ExecuteNonQuery();

            _logger.LogInformation("Successfully created table {TableName}", tableName);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error creating table {TableName}: {Message}", tableName, ex.Message);
            return false;
        }
    }

    /// <summary>Lists the IDENTITY columns of <paramref name="tableName"/>.</summary>
    /// <returns>The identity column names, or an empty list on error (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public List<string> GetIdentityColumns(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            // QUOTENAME the parts so OBJECT_ID also resolves names with spaces or other special
            // characters. Unquoted, it would return NULL and the identity column would be missed.
            const string query = @"
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_NAME = @TableName
            AND COLUMNPROPERTY(OBJECT_ID(QUOTENAME(TABLE_SCHEMA) + '.' + QUOTENAME(TABLE_NAME)), COLUMN_NAME, 'IsIdentity') = 1";

            using var command = new SqlCommand(query, connection);
            command.Parameters.AddWithValue("@TableName", tableName);

            var columns = new List<string>();
            using var reader = command.ExecuteReader();
            while (reader.Read())
            {
                columns.Add(reader.GetString(0));
            }

            return columns;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error getting identity columns for table {TableName}: {Message}", tableName, ex.Message);
            return [];
        }
    }

    /// <summary>Turns IDENTITY_INSERT on for <paramref name="tableName"/> on this connection's session.</summary>
    /// <returns><c>true</c> on success; otherwise <c>false</c> (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool EnableIdentityInsert(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            var query = $"SET IDENTITY_INSERT {QuoteIdentifier(tableName)} ON";
            using var command = new SqlCommand(query, connection);
            command.ExecuteNonQuery();

            _logger.LogDebug("Enabled IDENTITY_INSERT for {TableName}", tableName);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error enabling IDENTITY_INSERT for {TableName}: {Message}", tableName, ex.Message);
            return false;
        }
    }

    /// <summary>Turns IDENTITY_INSERT off for <paramref name="tableName"/> on this connection's session.</summary>
    /// <returns><c>true</c> on success; otherwise <c>false</c> (the error is logged).</returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool DisableIdentityInsert(string tableName)
    {
        var connection = RequireConnection();

        try
        {
            var query = $"SET IDENTITY_INSERT {QuoteIdentifier(tableName)} OFF";
            using var command = new SqlCommand(query, connection);
            command.ExecuteNonQuery();

            _logger.LogDebug("Disabled IDENTITY_INSERT for {TableName}", tableName);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error disabling IDENTITY_INSERT for {TableName}: {Message}", tableName, ex.Message);
            return false;
        }
    }

    /// <summary>
    /// Inserts <paramref name="data"/> into <paramref name="tableName"/>.
    /// Source columns that do not exist in the target table are skipped. When identity columns
    /// are supplied, IDENTITY_INSERT is switched on for the duration of the import.
    /// </summary>
    /// <param name="tableName">Destination table.</param>
    /// <param name="columns">Source column names, positionally matching each row in <paramref name="data"/>.</param>
    /// <param name="data">Rows to insert. Rows shorter than <paramref name="columns"/> are padded with NULL.</param>
    /// <param name="batchSize">Rows per transaction. Must be positive; otherwise the import fails and returns <c>false</c>.</param>
    /// <returns>
    /// <c>true</c> if all rows were inserted and the post-import row count is consistent. Otherwise
    /// <c>false</c> (the error is logged). An empty <paramref name="data"/> returns <c>true</c>.
    /// </returns>
    /// <exception cref="InvalidOperationException">Not connected.</exception>
    public bool ImportData(
        string tableName,
        List<string> columns,
        List<object[]> data,
        int batchSize = 1000)
    {
        RequireConnection();

        if (data.Count == 0)
        {
            _logger.LogWarning("No data to import for table {TableName}", tableName);
            return true;
        }

        // Tracks whether this call turned IDENTITY_INSERT on, so the finally block always turns it
        // off again. It is a session setting, and leaving it ON would break later identity imports.
        var identityInsertEnabled = false;
        try
        {
            var actualColumns = GetTableColumns(tableName);
            if (actualColumns.Count == 0)
            {
                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
                return false;
            }

            // Keep only source columns that exist in the destination, remembering their positions.
            var actualColumnSet = new HashSet<string>(actualColumns);
            var validColumnIndices = new List<int>();
            var validColumns = new List<string>();
            var missingColumns = new List<string>();

            for (var i = 0; i < columns.Count; i++)
            {
                if (actualColumnSet.Contains(columns[i]))
                {
                    validColumnIndices.Add(i);
                    validColumns.Add(columns[i]);
                }
                else
                {
                    missingColumns.Add(columns[i]);
                }
            }

            if (missingColumns.Count > 0)
            {
                _logger.LogWarning("Skipping columns that don't exist in {TableName}: {Columns}", tableName, string.Join(", ", missingColumns));
            }

            if (validColumns.Count == 0)
            {
                _logger.LogError("No valid columns found for table {TableName}", tableName);
                return false;
            }

            // Project each row onto the valid columns. Short rows are padded with null.
            var filteredData = data.Select(row =>
                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
            ).ToList();

            _logger.LogInformation("Valid columns for {TableName}: {Columns}", tableName, string.Join(", ", validColumns));

            var identityColumns = GetIdentityColumns(tableName);
            var identityColumnsInData = validColumns.Intersect(identityColumns).ToList();

            if (identityColumnsInData.Count > 0)
            {
                _logger.LogInformation("Table {TableName} has identity columns in import data: {Columns}", tableName, string.Join(", ", identityColumnsInData));
                _logger.LogInformation("Enabling IDENTITY_INSERT to allow explicit identity values");

                if (!EnableIdentityInsert(tableName))
                {
                    _logger.LogError("Could not enable IDENTITY_INSERT for {TableName}", tableName);
                    return false;
                }

                identityInsertEnabled = true;
            }

            var existingCount = GetTableRowCount(tableName);
            if (existingCount > 0)
            {
                _logger.LogWarning("Table {TableName} already contains {ExistingCount} rows - potential for primary key conflicts", tableName, existingCount);
            }

            var totalImported = FastBatchImport(tableName, validColumns, filteredData, batchSize);

            _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName);

            var actualCount = GetTableRowCount(tableName);
            _logger.LogInformation("Post-import validation for {TableName}: imported {TotalImported}, table contains {ActualCount}", tableName, totalImported, actualCount);

            // Count pre-existing rows too; otherwise the check is vacuous on a non-empty table.
            var expectedMinimum = existingCount + totalImported;
            if (actualCount < expectedMinimum)
            {
                _logger.LogError("Import validation failed for {TableName}: expected at least {Expected}, found {Actual}", tableName, expectedMinimum, actualCount);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error importing data into {TableName}: {Message}", tableName, ex.Message);
            return false;
        }
        finally
        {
            if (identityInsertEnabled && !DisableIdentityInsert(tableName))
            {
                _logger.LogWarning("Could not disable IDENTITY_INSERT for {TableName}", tableName);
            }
        }
    }

    /// <summary>
    /// Bulk-loads <paramref name="data"/> with <see cref="SqlBulkCopy"/>. On failure it falls back
    /// to <see cref="FastBatchImport"/>.
    /// </summary>
    /// <returns>The number of rows written.</returns>
    /// <remarks>
    /// This method is not called anywhere in this class.
    /// NOTE(review): with BatchSize set and no external transaction, batches already committed
    /// before a failure stay committed, and the fallback would then insert them a second time.
    /// NOTE(review): SqlBulkCopy without the KeepIdentity option presumably lets the destination
    /// assign identity values, unlike the IDENTITY_INSERT path in <see cref="ImportData"/>.
    /// Both points need fixing before this is wired in.
    /// </remarks>
    private int UseSqlBulkCopy(string tableName, List<string> columns, List<object?[]> data)
    {
        var connection = RequireConnection();

        try
        {
            using var bulkCopy = new SqlBulkCopy(connection)
            {
                DestinationTableName = QuoteIdentifier(tableName),
                BatchSize = 10000,
                BulkCopyTimeout = 600 // 10 minutes
            };

            // Source and destination column names are identical.
            foreach (var column in columns)
            {
                bulkCopy.ColumnMappings.Add(column, column);
            }

            using var dataTable = new DataTable();
            foreach (var column in columns)
            {
                dataTable.Columns.Add(column, typeof(object));
            }

            foreach (var row in data)
            {
                dataTable.Rows.Add(PrepareRowForInsert(row, columns));
            }

            _logger.LogInformation("Using SqlBulkCopy for {Count} rows", data.Count);
            bulkCopy.WriteToServer(dataTable);

            return data.Count;
        }
        catch (Exception ex)
        {
            _logger.LogWarning("SqlBulkCopy failed: {Message}, falling back to batch insert", ex.Message);
            return FastBatchImport(tableName, columns, data, 1000);
        }
    }

    /// <summary>
    /// Inserts rows one parameterized INSERT at a time. Each group of <paramref name="batchSize"/>
    /// rows is committed in its own transaction.
    /// </summary>
    /// <returns>The number of rows inserted.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is not positive.</exception>
    /// <remarks>
    /// A failing row rolls back its batch and rethrows. Earlier batches remain committed.
    /// </remarks>
    private int FastBatchImport(string tableName, List<string> columns, List<object?[]> data, int batchSize)
    {
        // A non-positive batch size would never advance the loop below.
        if (batchSize <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive.");
        }

        var connection = RequireConnection();

        var columnList = string.Join(", ", columns.Select(QuoteIdentifier));
        var placeholders = string.Join(", ", columns.Select((_, i) => $"@p{i}"));
        var insertSql = $"INSERT INTO {QuoteIdentifier(tableName)} ({columnList}) VALUES ({placeholders})";

        var totalImported = 0;

        for (var offset = 0; offset < data.Count; offset += batchSize)
        {
            var batch = data.GetRange(offset, Math.Min(batchSize, data.Count - offset));

            using var transaction = connection.BeginTransaction();
            try
            {
                foreach (var row in batch)
                {
                    using var command = new SqlCommand(insertSql, connection, transaction);

                    var preparedRow = PrepareRowForInsert(row, columns);
                    for (var p = 0; p < preparedRow.Length; p++)
                    {
                        command.Parameters.AddWithValue($"@p{p}", preparedRow[p] ?? DBNull.Value);
                    }

                    command.ExecuteNonQuery();
                }

                transaction.Commit();
                totalImported += batch.Count;

                if (data.Count > 1000)
                {
                    _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{DataCount} total, {Percentage:F1}%)", batch.Count, totalImported, data.Count, (totalImported / (double)data.Count * 100));
                }
            }
            catch
            {
                // Rollback itself can throw (e.g. the server already rolled the transaction back or
                // the connection broke). Don't let that hide the error that caused the failure.
                try
                {
                    transaction.Rollback();
                }
                catch (Exception rollbackEx)
                {
                    _logger.LogWarning("Rollback of batch for {TableName} failed: {Message}", tableName, rollbackEx.Message);
                }

                throw;
            }
        }

        return totalImported;
    }

    /// <summary>
    /// Converts every value of <paramref name="row"/> with <see cref="ConvertValueForSqlServer"/>.
    /// <paramref name="columns"/> is currently unused.
    /// </summary>
    private object[] PrepareRowForInsert(object?[] row, List<string> columns)
    {
        var preparedRow = new object[row.Length];

        for (var i = 0; i < row.Length; i++)
        {
            preparedRow[i] = ConvertValueForSqlServer(row[i]);
        }

        return preparedRow;
    }

    /// <summary>
    /// Normalizes one exported value before it is sent to SQL Server:
    /// <list type="bullet">
    /// <item><description><c>null</c>/<see cref="DBNull"/> and the empty string become <see cref="DBNull.Value"/>.</description></item>
    /// <item><description>"true"/"false" (any case) become the integers 1/0.</description></item>
    /// <item><description>
    /// ISO-like date/time strings have their timezone removed (via <c>DateTimeHelper</c>) and
    /// fractional seconds truncated to 3 digits.
    /// </description></item>
    /// <item><description>Everything else is returned unchanged.</description></item>
    /// </list>
    /// </summary>
    private object ConvertValueForSqlServer(object? value)
    {
        if (value == null || value == DBNull.Value)
            return DBNull.Value;

        if (value is string strValue)
        {
            // Only truly empty strings become NULL; whitespace is preserved (e.g. inside JSON).
            if (strValue.Length == 0)
                return DBNull.Value;

            // NOTE(review): this applies to every string, including text columns whose content is
            // literally "true"/"false". Those are sent as the integers 1/0.
            if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase))
                return 1;
            if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase))
                return 0;

            if (DateTimeHelper.IsLikelyIsoDateTime(strValue))
            {
                try
                {
                    var datetimePart = strValue.Contains('+') || strValue.EndsWith('Z') || strValue.Contains('T')
                        ? DateTimeHelper.RemoveTimezone(strValue) ?? strValue
                        : strValue;

                    // SQL Server DATETIME keeps at most 3 fractional digits, so truncate longer fractions.
                    // NOTE(review): this also truncates values bound for DATETIME2 columns, which could
                    // have kept up to 7 digits. Confirm the precision loss is acceptable.
                    if (datetimePart.Contains('.'))
                    {
                        var parts = datetimePart.Split('.');
                        if (parts.Length == 2 && parts[1].Length > 3)
                        {
                            datetimePart = $"{parts[0]}.{parts[1][..3]}";
                        }
                    }

                    return datetimePart;
                }
                catch
                {
                    // Best effort: if normalization fails, fall through and send the original value.
                }
            }
        }

        return value;
    }

    /// <summary>
    /// Opens a connection, runs <c>SELECT 1</c>, and always disconnects afterwards.
    /// </summary>
    /// <returns><c>true</c> if the probe query returned 1; otherwise <c>false</c> (errors are logged).</returns>
    public bool TestConnection()
    {
        try
        {
            if (!Connect())
            {
                return false;
            }

            try
            {
                using var command = new SqlCommand("SELECT 1", _connection);
                var result = command.ExecuteScalar();
                return result is 1;
            }
            finally
            {
                // Release the probe connection even when the query throws.
                Disconnect();
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "SQL Server import connection test failed: {Message}", ex.Message);
            return false;
        }
    }

    /// <summary>Releases the connection; equivalent to <see cref="Disconnect"/>.</summary>
    public void Dispose()
    {
        Disconnect();
    }
}
|