1
0
mirror of https://github.com/bitwarden/server synced 2025-12-24 12:13:17 +00:00

Ensure constraints are tracked correctly for SQL Server

This commit is contained in:
Mark Kincaid
2025-10-29 16:42:55 -07:00
parent 8cf7327ca6
commit 1c91178b25

View File

@@ -15,7 +15,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
private readonly string _username = config.Username;
private readonly string _password = config.Password;
private SqlConnection? _connection;
private List<(string Schema, string Table, string Constraint)> _disabledConstraints = [];
private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]";
public bool Connect()
{
@@ -179,6 +179,75 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
private bool DropTrackingTable()
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
try
{
var dropSql = $"DROP TABLE IF EXISTS {_trackingTableName}";
using var command = new SqlCommand(dropSql, _connection);
command.ExecuteNonQuery();
_logger.LogDebug("Dropped tracking table {TrackingTableName}", _trackingTableName);
return true;
}
catch (Exception ex)
{
_logger.LogWarning("Error dropping tracking table: {Message}", ex.Message);
return false;
}
}
private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable()
{
if (_connection == null)
throw new InvalidOperationException("Not connected to database");
var constraints = new List<(string Schema, string Table, string Constraint)>();
try
{
// Check if tracking table exists
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
using var checkCommand = new SqlCommand(checkSql, _connection);
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
if (!tableExists)
{
_logger.LogDebug("Tracking table does not exist, no constraints to re-enable");
return constraints;
}
// Get only constraints that we disabled (PreExistingDisabled = 0)
var querySql = $@"
SELECT SchemaName, TableName, ConstraintName
FROM {_trackingTableName}
WHERE PreExistingDisabled = 0";
using var command = new SqlCommand(querySql, _connection);
using var reader = command.ExecuteReader();
while (reader.Read())
{
constraints.Add((
reader.GetString(0),
reader.GetString(1),
reader.GetString(2)
));
}
_logger.LogDebug("Found {Count} constraints to re-enable from tracking table", constraints.Count);
}
catch (Exception ex)
{
_logger.LogWarning("Error reading tracking table: {Message}", ex.Message);
}
return constraints;
}
public bool DisableForeignKeys()
{
if (_connection == null)
@@ -188,50 +257,153 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
_logger.LogInformation("Disabling foreign key constraints for SQL Server");
// Get all foreign key constraints
var query = @"
SELECT
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
OBJECT_NAME(parent_object_id) AS table_name,
name AS constraint_name
FROM sys.foreign_keys
WHERE is_disabled = 0";
using var command = new SqlCommand(query, _connection);
using var reader = command.ExecuteReader();
var constraints = new List<(string Schema, string Table, string Constraint)>();
while (reader.Read())
// Check if tracking table already exists
var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')";
using (var checkCommand = new SqlCommand(checkSql, _connection))
{
constraints.Add((
reader.GetString(0),
reader.GetString(1),
reader.GetString(2)
));
}
reader.Close();
var tableExists = (int)checkCommand.ExecuteScalar()! > 0;
// Disable each constraint
_disabledConstraints = [];
foreach (var (schema, table, constraint) in constraints)
{
try
if (tableExists)
{
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
using var disableCommand = new SqlCommand(disableSql, _connection);
disableCommand.ExecuteNonQuery();
_disabledConstraints.Add((schema, table, constraint));
_logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
}
catch (Exception ex)
{
_logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message);
// Table exists - this means we're resuming from an interrupted run
// Constraints are already disabled and tracked
_logger.LogInformation("Tracking table already exists - resuming from previous interrupted run");
_logger.LogInformation("Foreign key constraints are already disabled");
return true;
}
}
_logger.LogInformation("Disabled {Count} foreign key constraints", _disabledConstraints.Count);
return true;
// Table doesn't exist - this is a fresh run
// Create table and disable constraints in a transaction for atomicity
using var transaction = _connection.BeginTransaction();
try
{
// Create tracking table
var createSql = $@"
CREATE TABLE {_trackingTableName} (
SchemaName NVARCHAR(128) NOT NULL,
TableName NVARCHAR(128) NOT NULL,
ConstraintName NVARCHAR(128) NOT NULL,
PreExistingDisabled BIT NOT NULL,
DisabledAt DATETIME2 NOT NULL DEFAULT GETDATE()
)";
using (var createCommand = new SqlCommand(createSql, _connection, transaction))
{
createCommand.ExecuteNonQuery();
}
_logger.LogDebug("Created tracking table {TrackingTableName}", _trackingTableName);
// First, get all PRE-EXISTING disabled foreign key constraints
var preExistingQuery = @"
SELECT
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
OBJECT_NAME(parent_object_id) AS table_name,
name AS constraint_name
FROM sys.foreign_keys
WHERE is_disabled = 1";
var preExistingConstraints = new List<(string Schema, string Table, string Constraint)>();
using (var preCommand = new SqlCommand(preExistingQuery, _connection, transaction))
using (var preReader = preCommand.ExecuteReader())
{
while (preReader.Read())
{
preExistingConstraints.Add((
preReader.GetString(0),
preReader.GetString(1),
preReader.GetString(2)
));
}
}
// Store pre-existing disabled constraints
foreach (var (schema, table, constraint) in preExistingConstraints)
{
var insertSql = $@"
INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled)
VALUES (@Schema, @Table, @Constraint, 1)";
using var insertCommand = new SqlCommand(insertSql, _connection, transaction);
insertCommand.Parameters.AddWithValue("@Schema", schema);
insertCommand.Parameters.AddWithValue("@Table", table);
insertCommand.Parameters.AddWithValue("@Constraint", constraint);
insertCommand.ExecuteNonQuery();
}
if (preExistingConstraints.Count > 0)
{
_logger.LogInformation("Found {Count} pre-existing disabled constraints", preExistingConstraints.Count);
}
// Now get all ENABLED foreign key constraints
var enabledQuery = @"
SELECT
OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name,
OBJECT_NAME(parent_object_id) AS table_name,
name AS constraint_name
FROM sys.foreign_keys
WHERE is_disabled = 0";
var constraints = new List<(string Schema, string Table, string Constraint)>();
using (var command = new SqlCommand(enabledQuery, _connection, transaction))
using (var reader = command.ExecuteReader())
{
while (reader.Read())
{
constraints.Add((
reader.GetString(0),
reader.GetString(1),
reader.GetString(2)
));
}
}
// Disable each enabled constraint and track it
var disabledCount = 0;
foreach (var (schema, table, constraint) in constraints)
{
try
{
// Disable the constraint
var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]";
using var disableCommand = new SqlCommand(disableSql, _connection, transaction);
disableCommand.ExecuteNonQuery();
// Store in tracking table with PreExistingDisabled = false
var insertSql = $@"
INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled)
VALUES (@Schema, @Table, @Constraint, 0)";
using var insertCommand = new SqlCommand(insertSql, _connection, transaction);
insertCommand.Parameters.AddWithValue("@Schema", schema);
insertCommand.Parameters.AddWithValue("@Table", table);
insertCommand.Parameters.AddWithValue("@Constraint", constraint);
insertCommand.ExecuteNonQuery();
disabledCount++;
_logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table);
}
catch (Exception ex)
{
_logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message);
}
}
// Commit the transaction - this makes everything atomic
transaction.Commit();
_logger.LogInformation("Disabled {Count} foreign key constraints", disabledCount);
return true;
}
catch
{
// If anything fails, rollback the transaction
// This ensures the tracking table doesn't exist with incomplete data
transaction.Rollback();
throw;
}
}
catch (Exception ex)
{
@@ -249,8 +421,19 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
{
_logger.LogInformation("Re-enabling foreign key constraints for SQL Server");
// Get constraints that we disabled (PreExistingDisabled = 0) from tracking table
var constraintsToReEnable = GetConstraintsToReEnable();
if (constraintsToReEnable.Count == 0)
{
_logger.LogInformation("No constraints to re-enable");
// Still drop tracking table to clean up
DropTrackingTable();
return true;
}
var enabledCount = 0;
foreach (var (schema, table, constraint) in _disabledConstraints)
foreach (var (schema, table, constraint) in constraintsToReEnable)
{
try
{
@@ -267,8 +450,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger<SqlServerImporter>
}
}
_logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, _disabledConstraints.Count);
_disabledConstraints.Clear();
_logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, constraintsToReEnable.Count);
// Drop tracking table to clean up
DropTrackingTable();
return true;
}