mirror of
https://github.com/bitwarden/server
synced 2025-12-24 12:13:17 +00:00
Ensure constraints are tracked correctly for SQL Server
This commit is contained in:
@@ -15,7 +15,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
private readonly string _username = config.Username;
|
||||
private readonly string _password = config.Password;
|
||||
private SqlConnection? _connection;
|
||||
private List<(string Schema, string Table, string Constraint)> _disabledConstraints = [];
|
||||
private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]";
|
||||
|
||||
public bool Connect()
|
||||
{
|
||||
@@ -179,6 +179,75 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
private bool DropTrackingTable()
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
try
|
||||
{
|
||||
var dropSql = $"DROP TABLE IF EXISTS {_trackingTableName}";
|
||||
using var command = new SqlCommand(dropSql, _connection);
|
||||
command.ExecuteNonQuery();
|
||||
|
||||
_logger.LogDebug("Dropped tracking table {TrackingTableName}", _trackingTableName);
|
||||
return true;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning("Error dropping tracking table: {Message}", ex.Message);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable()
|
||||
{
|
||||
if (_connection == null)
|
||||
throw new InvalidOperationException("Not connected to database");
|
||||
|
||||
var constraints = new List<(string Schema, string Table, string Constraint)>();
|
||||
|
||||
try
|
||||
{
|
||||
// Check if tracking table exists
|
||||
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
|
||||
using var checkCommand = new SqlCommand(checkSql, _connection);
|
||||
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
|
||||
|
||||
if (!tableExists)
|
||||
{
|
||||
_logger.LogDebug("Tracking table does not exist, no constraints to re-enable");
|
||||
return constraints;
|
||||
}
|
||||
|
||||
// Get only constraints that we disabled (PreExistingDisabled = 0)
|
||||
var querySql = $@"
|
||||
SELECT SchemaName, TableName, ConstraintName
|
||||
FROM {_trackingTableName}
|
||||
WHERE PreExistingDisabled = 0";
|
||||
|
||||
using var command = new SqlCommand(querySql, _connection);
|
||||
using var reader = command.ExecuteReader();
|
||||
|
||||
while (reader.Read())
|
||||
{
|
||||
constraints.Add((
|
||||
reader.GetString(0),
|
||||
reader.GetString(1),
|
||||
reader.GetString(2)
|
||||
));
|
||||
}
|
||||
|
||||
_logger.LogDebug("Found {Count} constraints to re-enable from tracking table", constraints.Count);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning("Error reading tracking table: {Message}", ex.Message);
|
||||
}
|
||||
|
||||
return constraints;
|
||||
}
|
||||
|
||||
public bool DisableForeignKeys()
|
||||
{
|
||||
if (_connection == null)
|
||||
@@ -188,50 +257,153 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
_logger.LogInformation("Disabling foreign key constraints for SQL Server");
|
||||
|
||||
// Get all foreign key constraints
|
||||
var query = @"
|
||||
SELECT
|
||||
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
|
||||
OBJECT_NAME(parent_object_id) AS table_name,
|
||||
name AS constraint_name
|
||||
FROM sys.foreign_keys
|
||||
WHERE is_disabled = 0";
|
||||
|
||||
using var command = new SqlCommand(query, _connection);
|
||||
using var reader = command.ExecuteReader();
|
||||
|
||||
var constraints = new List<(string Schema, string Table, string Constraint)>();
|
||||
while (reader.Read())
|
||||
// Check if tracking table already exists
|
||||
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
|
||||
using (var checkCommand = new SqlCommand(checkSql, _connection))
|
||||
{
|
||||
constraints.Add((
|
||||
reader.GetString(0),
|
||||
reader.GetString(1),
|
||||
reader.GetString(2)
|
||||
));
|
||||
}
|
||||
reader.Close();
|
||||
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
|
||||
|
||||
// Disable each constraint
|
||||
_disabledConstraints = [];
|
||||
foreach (var (schema, table, constraint) in constraints)
|
||||
{
|
||||
try
|
||||
if (tableExists)
|
||||
{
|
||||
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
|
||||
using var disableCommand = new SqlCommand(disableSql, _connection);
|
||||
disableCommand.ExecuteNonQuery();
|
||||
|
||||
_disabledConstraints.Add((schema, table, constraint));
|
||||
_logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message);
|
||||
// Table exists - this means we're resuming from an interrupted run
|
||||
// Constraints are already disabled and tracked
|
||||
_logger.LogInformation("Tracking table already exists - resuming from previous interrupted run");
|
||||
_logger.LogInformation("Foreign key constraints are already disabled");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Disabled {Count} foreign key constraints", _disabledConstraints.Count);
|
||||
return true;
|
||||
// Table doesn't exist - this is a fresh run
|
||||
// Create table and disable constraints in a transaction for atomicity
|
||||
using var transaction = _connection.BeginTransaction();
|
||||
try
|
||||
{
|
||||
// Create tracking table
|
||||
var createSql = $@"
|
||||
CREATE TABLE {_trackingTableName} (
|
||||
SchemaName NVARCHAR(128) NOT NULL,
|
||||
TableName NVARCHAR(128) NOT NULL,
|
||||
ConstraintName NVARCHAR(128) NOT NULL,
|
||||
PreExistingDisabled BIT NOT NULL,
|
||||
DisabledAt DATETIME2 NOT NULL DEFAULT GETDATE()
|
||||
)";
|
||||
|
||||
using (var createCommand = new SqlCommand(createSql, _connection, transaction))
|
||||
{
|
||||
createCommand.ExecuteNonQuery();
|
||||
}
|
||||
|
||||
_logger.LogDebug("Created tracking table {TrackingTableName}", _trackingTableName);
|
||||
|
||||
// First, get all PRE-EXISTING disabled foreign key constraints
|
||||
var preExistingQuery = @"
|
||||
SELECT
|
||||
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
|
||||
OBJECT_NAME(parent_object_id) AS table_name,
|
||||
name AS constraint_name
|
||||
FROM sys.foreign_keys
|
||||
WHERE is_disabled = 1";
|
||||
|
||||
var preExistingConstraints = new List<(string Schema, string Table, string Constraint)>();
|
||||
using (var preCommand = new SqlCommand(preExistingQuery, _connection, transaction))
|
||||
using (var preReader = preCommand.ExecuteReader())
|
||||
{
|
||||
while (preReader.Read())
|
||||
{
|
||||
preExistingConstraints.Add((
|
||||
preReader.GetString(0),
|
||||
preReader.GetString(1),
|
||||
preReader.GetString(2)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Store pre-existing disabled constraints
|
||||
foreach (var (schema, table, constraint) in preExistingConstraints)
|
||||
{
|
||||
var insertSql = $@"
|
||||
INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled)
|
||||
VALUES (@Schema, @Table, @Constraint, 1)";
|
||||
|
||||
using var insertCommand = new SqlCommand(insertSql, _connection, transaction);
|
||||
insertCommand.Parameters.AddWithValue("@Schema", schema);
|
||||
insertCommand.Parameters.AddWithValue("@Table", table);
|
||||
insertCommand.Parameters.AddWithValue("@Constraint", constraint);
|
||||
insertCommand.ExecuteNonQuery();
|
||||
}
|
||||
|
||||
if (preExistingConstraints.Count > 0)
|
||||
{
|
||||
_logger.LogInformation("Found {Count} pre-existing disabled constraints", preExistingConstraints.Count);
|
||||
}
|
||||
|
||||
// Now get all ENABLED foreign key constraints
|
||||
var enabledQuery = @"
|
||||
SELECT
|
||||
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
|
||||
OBJECT_NAME(parent_object_id) AS table_name,
|
||||
name AS constraint_name
|
||||
FROM sys.foreign_keys
|
||||
WHERE is_disabled = 0";
|
||||
|
||||
var constraints = new List<(string Schema, string Table, string Constraint)>();
|
||||
using (var command = new SqlCommand(enabledQuery, _connection, transaction))
|
||||
using (var reader = command.ExecuteReader())
|
||||
{
|
||||
while (reader.Read())
|
||||
{
|
||||
constraints.Add((
|
||||
reader.GetString(0),
|
||||
reader.GetString(1),
|
||||
reader.GetString(2)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Disable each enabled constraint and track it
|
||||
var disabledCount = 0;
|
||||
foreach (var (schema, table, constraint) in constraints)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Disable the constraint
|
||||
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
|
||||
using var disableCommand = new SqlCommand(disableSql, _connection, transaction);
|
||||
disableCommand.ExecuteNonQuery();
|
||||
|
||||
// Store in tracking table with PreExistingDisabled = false
|
||||
var insertSql = $@"
|
||||
INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled)
|
||||
VALUES (@Schema, @Table, @Constraint, 0)";
|
||||
|
||||
using var insertCommand = new SqlCommand(insertSql, _connection, transaction);
|
||||
insertCommand.Parameters.AddWithValue("@Schema", schema);
|
||||
insertCommand.Parameters.AddWithValue("@Table", table);
|
||||
insertCommand.Parameters.AddWithValue("@Constraint", constraint);
|
||||
insertCommand.ExecuteNonQuery();
|
||||
|
||||
disabledCount++;
|
||||
_logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction - this makes everything atomic
|
||||
transaction.Commit();
|
||||
|
||||
_logger.LogInformation("Disabled {Count} foreign key constraints", disabledCount);
|
||||
return true;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// If anything fails, rollback the transaction
|
||||
// This ensures the tracking table doesn't exist with incomplete data
|
||||
transaction.Rollback();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -249,8 +421,19 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
{
|
||||
_logger.LogInformation("Re-enabling foreign key constraints for SQL Server");
|
||||
|
||||
// Get constraints that we disabled (PreExistingDisabled = 0) from tracking table
|
||||
var constraintsToReEnable = GetConstraintsToReEnable();
|
||||
|
||||
if (constraintsToReEnable.Count == 0)
|
||||
{
|
||||
_logger.LogInformation("No constraints to re-enable");
|
||||
// Still drop tracking table to clean up
|
||||
DropTrackingTable();
|
||||
return true;
|
||||
}
|
||||
|
||||
var enabledCount = 0;
|
||||
foreach (var (schema, table, constraint) in _disabledConstraints)
|
||||
foreach (var (schema, table, constraint) in constraintsToReEnable)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -267,8 +450,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, _disabledConstraints.Count);
|
||||
_disabledConstraints.Clear();
|
||||
_logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, constraintsToReEnable.Count);
|
||||
|
||||
// Drop tracking table to clean up
|
||||
DropTrackingTable();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user