diff --git a/util/Seeder/Migration/Databases/SqlServerImporter.cs b/util/Seeder/Migration/Databases/SqlServerImporter.cs index 6636e7ca64..a7bcf77643 100644 --- a/util/Seeder/Migration/Databases/SqlServerImporter.cs +++ b/util/Seeder/Migration/Databases/SqlServerImporter.cs @@ -15,7 +15,7 @@ public class SqlServerImporter(DatabaseConfig config, ILogger private readonly string _username = config.Username; private readonly string _password = config.Password; private SqlConnection? _connection; - private List<(string Schema, string Table, string Constraint)> _disabledConstraints = []; + private const string _trackingTableName = "[dbo].[_MigrationDisabledConstraint]"; public bool Connect() { @@ -179,6 +179,75 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } + private bool DropTrackingTable() + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var dropSql = $"DROP TABLE IF EXISTS {_trackingTableName}"; + using var command = new SqlCommand(dropSql, _connection); + command.ExecuteNonQuery(); + + _logger.LogDebug("Dropped tracking table {TrackingTableName}", _trackingTableName); + return true; + } + catch (Exception ex) + { + _logger.LogWarning("Error dropping tracking table: {Message}", ex.Message); + return false; + } + } + + private List<(string Schema, string Table, string Constraint)> GetConstraintsToReEnable() + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + var constraints = new List<(string Schema, string Table, string Constraint)>(); + + try + { + // Check if tracking table exists + var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')"; + using var checkCommand = new SqlCommand(checkSql, _connection); + var tableExists = (int)checkCommand.ExecuteScalar()! > 0; + + if (!tableExists) + { + _logger.LogDebug("Tracking table does not exist, no constraints to re-enable"); + return constraints; + } + + // Get only constraints that we disabled (PreExistingDisabled = 0) + var querySql = $@" + SELECT SchemaName, TableName, ConstraintName + FROM {_trackingTableName} + WHERE PreExistingDisabled = 0"; + + using var command = new SqlCommand(querySql, _connection); + using var reader = command.ExecuteReader(); + + while (reader.Read()) + { + constraints.Add(( + reader.GetString(0), + reader.GetString(1), + reader.GetString(2) + )); + } + + _logger.LogDebug("Found {Count} constraints to re-enable from tracking table", constraints.Count); + } + catch (Exception ex) + { + _logger.LogWarning("Error reading tracking table: {Message}", ex.Message); + } + + return constraints; + } + public bool DisableForeignKeys() { if (_connection == null) @@ -188,50 +257,153 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { _logger.LogInformation("Disabling foreign key constraints for SQL Server"); - // Get all foreign key constraints - var query = @" - SELECT - OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name, - OBJECT_NAME(parent_object_id) AS table_name, - name AS constraint_name - FROM sys.foreign_keys - WHERE is_disabled = 0"; - - using var command = new SqlCommand(query, _connection); - using var reader = command.ExecuteReader(); - - var constraints = new List<(string Schema, string Table, string Constraint)>(); - while (reader.Read()) + // Check if tracking table already exists + var checkSql = "SELECT COUNT(*) FROM sys.tables WHERE name = '_MigrationDisabledConstraint' AND schema_id = SCHEMA_ID('dbo')"; + using (var checkCommand = new SqlCommand(checkSql, _connection)) { - constraints.Add(( - reader.GetString(0), - reader.GetString(1), - reader.GetString(2) - )); - } - reader.Close(); + var tableExists = (int)checkCommand.ExecuteScalar()! > 0; - // Disable each constraint - _disabledConstraints = []; - foreach (var (schema, table, constraint) in constraints) - { - try + if (tableExists) { - var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]"; - using var disableCommand = new SqlCommand(disableSql, _connection); - disableCommand.ExecuteNonQuery(); - - _disabledConstraints.Add((schema, table, constraint)); - _logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table); - } - catch (Exception ex) - { - _logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message); + // Table exists - this means we're resuming from an interrupted run + // Constraints are already disabled and tracked + _logger.LogInformation("Tracking table already exists - resuming from previous interrupted run"); + _logger.LogInformation("Foreign key constraints are already disabled"); + return true; } } - _logger.LogInformation("Disabled {Count} foreign key constraints", _disabledConstraints.Count); - return true; + // Table doesn't exist - this is a fresh run + // Create table and disable constraints in a transaction for atomicity + using var transaction = _connection.BeginTransaction(); + try + { + // Create tracking table + var createSql = $@" + CREATE TABLE {_trackingTableName} ( + SchemaName NVARCHAR(128) NOT NULL, + TableName NVARCHAR(128) NOT NULL, + ConstraintName NVARCHAR(128) NOT NULL, + PreExistingDisabled BIT NOT NULL, + DisabledAt DATETIME2 NOT NULL DEFAULT GETDATE() + )"; + + using (var createCommand = new SqlCommand(createSql, _connection, transaction)) + { + createCommand.ExecuteNonQuery(); + } + + _logger.LogDebug("Created tracking table {TrackingTableName}", _trackingTableName); + + // First, get all PRE-EXISTING disabled foreign key constraints + var preExistingQuery = @" + SELECT + OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name, + OBJECT_NAME(parent_object_id) AS table_name, + name AS constraint_name + FROM sys.foreign_keys + WHERE is_disabled = 1"; + + var preExistingConstraints = new List<(string Schema, string Table, string Constraint)>(); + using (var preCommand = new SqlCommand(preExistingQuery, _connection, transaction)) + using (var preReader = preCommand.ExecuteReader()) + { + while (preReader.Read()) + { + preExistingConstraints.Add(( + preReader.GetString(0), + preReader.GetString(1), + preReader.GetString(2) + )); + } + } + + // Store pre-existing disabled constraints + foreach (var (schema, table, constraint) in preExistingConstraints) + { + var insertSql = $@" + INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled) + VALUES (@Schema, @Table, @Constraint, 1)"; + + using var insertCommand = new SqlCommand(insertSql, _connection, transaction); + insertCommand.Parameters.AddWithValue("@Schema", schema); + insertCommand.Parameters.AddWithValue("@Table", table); + insertCommand.Parameters.AddWithValue("@Constraint", constraint); + insertCommand.ExecuteNonQuery(); + } + + if (preExistingConstraints.Count > 0) + { + _logger.LogInformation("Found {Count} pre-existing disabled constraints", preExistingConstraints.Count); + } + + // Now get all ENABLED foreign key constraints + var enabledQuery = @" + SELECT + OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name, + OBJECT_NAME(parent_object_id) AS table_name, + name AS constraint_name + FROM sys.foreign_keys + WHERE is_disabled = 0"; + + var constraints = new List<(string Schema, string Table, string Constraint)>(); + using (var command = new SqlCommand(enabledQuery, _connection, transaction)) + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + constraints.Add(( + reader.GetString(0), + reader.GetString(1), + reader.GetString(2) + )); + } + } + + // Disable each enabled constraint and track it + var disabledCount = 0; + foreach (var (schema, table, constraint) in constraints) + { + try + { + // Disable the constraint + var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]"; + using var disableCommand = new SqlCommand(disableSql, _connection, transaction); + disableCommand.ExecuteNonQuery(); + + // Store in tracking table with PreExistingDisabled = false + var insertSql = $@" + INSERT INTO {_trackingTableName} (SchemaName, TableName, ConstraintName, PreExistingDisabled) + VALUES (@Schema, @Table, @Constraint, 0)"; + + using var insertCommand = new SqlCommand(insertSql, _connection, transaction); + insertCommand.Parameters.AddWithValue("@Schema", schema); + insertCommand.Parameters.AddWithValue("@Table", table); + insertCommand.Parameters.AddWithValue("@Constraint", constraint); + insertCommand.ExecuteNonQuery(); + + disabledCount++; + _logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table); + } + catch (Exception ex) + { + _logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message); + } + } + + // Commit the transaction - this makes everything atomic + transaction.Commit(); + + _logger.LogInformation("Disabled {Count} foreign key constraints", disabledCount); + return true; + } + catch + { + // If anything fails, rollback the transaction + // This ensures the tracking table doesn't exist with incomplete data + transaction.Rollback(); + throw; + } } catch (Exception ex) { @@ -249,8 +421,19 @@ public class SqlServerImporter(DatabaseConfig config, ILogger { _logger.LogInformation("Re-enabling foreign key constraints for SQL Server"); + // Get constraints that we disabled (PreExistingDisabled = 0) from tracking table + var constraintsToReEnable = GetConstraintsToReEnable(); + + if (constraintsToReEnable.Count == 0) + { + _logger.LogInformation("No constraints to re-enable"); + // Still drop tracking table to clean up + DropTrackingTable(); + return true; + } + var enabledCount = 0; - foreach (var (schema, table, constraint) in _disabledConstraints) + foreach (var (schema, table, constraint) in constraintsToReEnable) { try { @@ -267,8 +450,10 @@ public class SqlServerImporter(DatabaseConfig config, ILogger } } - _logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, _disabledConstraints.Count); - _disabledConstraints.Clear(); + _logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, constraintsToReEnable.Count); + + // Drop tracking table to clean up + DropTrackingTable(); return true; }