From e104fe6c5cf776e9a2700b7c309abee6aff3ca73 Mon Sep 17 00:00:00 2001 From: Mark Kincaid Date: Wed, 15 Oct 2025 07:45:22 -0700 Subject: [PATCH] Added CSV imports to DbSeeder utility --- .vscode/launch.json | 19 + .vscode/tasks.json | 13 + util/DbSeederUtility/.gitignore | 339 +++++++++ .../MigrationSettingsFactory.cs | 41 ++ util/DbSeederUtility/Program.cs | 185 ++++- util/Seeder/Migration/CsvHandler.cs | 375 ++++++++++ .../Migration/Databases/MariaDbImporter.cs | 439 ++++++++++++ .../Migration/Databases/PostgresImporter.cs | 552 ++++++++++++++ .../Migration/Databases/SqlServerExporter.cs | 380 ++++++++++ .../Migration/Databases/SqlServerImporter.cs | 671 ++++++++++++++++++ .../Migration/Databases/SqliteImporter.cs | 442 ++++++++++++ util/Seeder/Migration/Models/Config.cs | 52 ++ .../Seeder/Migration/Models/ReporterModels.cs | 123 ++++ .../Migration/Reporters/ExportReporter.cs | 302 ++++++++ .../Migration/Reporters/ImportReporter.cs | 308 ++++++++ .../Reporters/VerificationReporter.cs | 299 ++++++++ util/Seeder/Migration/SchemaMapper.cs | 209 ++++++ util/Seeder/Migration/TableFilter.cs | 209 ++++++ util/Seeder/Migration/Utils/DateTimeHelper.cs | 56 ++ .../Migration/Utils/SecuritySanitizer.cs | 70 ++ util/Seeder/Migration/Utils/SshTunnel.cs | 271 +++++++ util/Seeder/Recipes/CsvMigrationRecipe.cs | 545 ++++++++++++++ util/Seeder/Seeder.csproj | 9 + 23 files changed, 5907 insertions(+), 2 deletions(-) create mode 100644 util/DbSeederUtility/.gitignore create mode 100644 util/DbSeederUtility/MigrationSettingsFactory.cs create mode 100644 util/Seeder/Migration/CsvHandler.cs create mode 100644 util/Seeder/Migration/Databases/MariaDbImporter.cs create mode 100644 util/Seeder/Migration/Databases/PostgresImporter.cs create mode 100644 util/Seeder/Migration/Databases/SqlServerExporter.cs create mode 100644 util/Seeder/Migration/Databases/SqlServerImporter.cs create mode 100644 util/Seeder/Migration/Databases/SqliteImporter.cs create mode 100644 util/Seeder/Migration/Models/Config.cs create mode 100644 util/Seeder/Migration/Models/ReporterModels.cs create mode 100644 util/Seeder/Migration/Reporters/ExportReporter.cs create mode 100644 util/Seeder/Migration/Reporters/ImportReporter.cs create mode 100644 util/Seeder/Migration/Reporters/VerificationReporter.cs create mode 100644 util/Seeder/Migration/SchemaMapper.cs create mode 100644 util/Seeder/Migration/TableFilter.cs create mode 100644 util/Seeder/Migration/Utils/DateTimeHelper.cs create mode 100644 util/Seeder/Migration/Utils/SecuritySanitizer.cs create mode 100644 util/Seeder/Migration/Utils/SshTunnel.cs create mode 100644 util/Seeder/Recipes/CsvMigrationRecipe.cs diff --git a/.vscode/launch.json b/.vscode/launch.json index c407ba5604..59b6c633f9 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -600,6 +600,25 @@ "type": "coreclr", "request": "attach", "processId": "${command:pickProcess}" + }, + { + "name": "DbSeeder Utility", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "buildDbSeeder", + "program": "${workspaceFolder}/util/DbSeederUtility/bin/Debug/net8.0/DbSeeder.dll", + "args": ["organization", "-n", "testorg", "-u", "100", "-d", "test.local"], + "cwd": "${workspaceFolder}/util/DbSeederUtility", + "stopAtEntry": false, + "console": "internalConsole", + "env": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "presentation": { + "hidden": false, + "group": "utilities", + "order": 1 + } } ], } diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 567f9b6e58..62cb3fb779 100644 --- a/.vscode/tasks.json +++ 
b/.vscode/tasks.json @@ -220,6 +220,19 @@ "isDefault": true } }, + { + "label": "buildDbSeeder", + "hide": true, + "command": "dotnet", + "type": "process", + "args": [ + "build", + "${workspaceFolder}/util/DbSeederUtility/DbSeederUtility.csproj", + "/property:GenerateFullPaths=true", + "/consoleloggerparameters:NoSummary" + ], + "problemMatcher": "$msCompile" + }, { "label": "test", "type": "shell", diff --git a/util/DbSeederUtility/.gitignore b/util/DbSeederUtility/.gitignore new file mode 100644 index 0000000000..d7aadb0211 --- /dev/null +++ b/util/DbSeederUtility/.gitignore @@ -0,0 +1,339 @@ +# Configuration files (may contain sensitive data) +config.yaml +.env + +# Exported data +exports/ +*.csv +*.db +*.sqlite +*.sqlite3 + +# Logs +logs/ + +## .NET / C# artifacts + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio cache/options +.vs/ +.vscode/ +*.suo +*.user +*.userosscache +*.sln.docstates + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# ReSharper +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JetBrains Rider +.idea/ +*.sln.iml + +# NuGet +*.nupkg +*.snupkg +**/packages/* +!**/packages/build/ +*.nuget.props +*.nuget.targets +project.lock.json +project.fragment.lock.json +artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool +coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings +PublishScripts/ + +# NuGet Packages +*.nupkg +*.snupkg +**/[Pp]ackages/* +!**/[Pp]ackages/build/ +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual 
Studio cache files +*.suo +*.user +*.userosscache +*.sln.docstates + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) +*.vbw + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml + +# IDE +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db \ No newline at end of file diff --git a/util/DbSeederUtility/MigrationSettingsFactory.cs b/util/DbSeederUtility/MigrationSettingsFactory.cs new file mode 100644 index 0000000000..0d7b178f9e --- /dev/null +++ b/util/DbSeederUtility/MigrationSettingsFactory.cs @@ -0,0 +1,41 @@ +using Bit.Seeder.Migration.Models; +using Microsoft.Extensions.Configuration; + +namespace Bit.DbSeederUtility; + +public static class MigrationSettingsFactory +{ + private static MigrationConfig? _migrationConfig; + + public static MigrationConfig MigrationConfig + { + get { return _migrationConfig ??= LoadMigrationConfig(); } + } + + private static MigrationConfig LoadMigrationConfig() + { + Console.WriteLine("Loading migration configuration..."); + + var configBuilder = new ConfigurationBuilder() + .SetBasePath(Directory.GetCurrentDirectory()) + .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{Environment.GetEnvironmentVariable("ASPNETCORE_ENVIRONMENT") ?? "Production"}.json", optional: true, reloadOnChange: true) + .AddUserSecrets("bitwarden-Api") + .AddEnvironmentVariables(); + + var configuration = configBuilder.Build(); + var migrationSection = configuration.GetSection("migration"); + + var config = new MigrationConfig(); + migrationSection.Bind(config); + + // Log configuration status + Console.WriteLine($"Migration configuration loaded:"); + Console.WriteLine($" Source DB: {(config.Source != null ? 
"Configured" : "Not configured")}"); + Console.WriteLine($" Destinations: {config.Destinations.Count} configured"); + Console.WriteLine($" CSV Output Dir: {config.CsvSettings.OutputDir}"); + Console.WriteLine($" Excluded Tables: {config.ExcludeTables.Count}"); + + return config; + } +} diff --git a/util/DbSeederUtility/Program.cs b/util/DbSeederUtility/Program.cs index 2d75b31934..a754150002 100644 --- a/util/DbSeederUtility/Program.cs +++ b/util/DbSeederUtility/Program.cs @@ -1,7 +1,9 @@ using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Seeder.Migration; using Bit.Seeder.Recipes; using CommandDotNet; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; namespace Bit.DbSeederUtility; @@ -23,12 +25,10 @@ public class Program string domain ) { - // Create service provider with necessary services var services = new ServiceCollection(); ServiceCollectionExtension.ConfigureServices(services); var serviceProvider = services.BuildServiceProvider(); - // Get a scoped DB context using var scope = serviceProvider.CreateScope(); var scopedServices = scope.ServiceProvider; var db = scopedServices.GetRequiredService(); @@ -36,4 +36,185 @@ public class Program var recipe = new OrganizationWithUsersRecipe(db); recipe.Seed(name, users, domain); } + + [Command("discover", Description = "Discover and analyze tables in source database")] + public void Discover( + [Option("startssh", Description = "Start SSH tunnel before operation")] + bool startSsh = false + ) + { + var config = MigrationSettingsFactory.MigrationConfig; + var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + var recipe = new CsvMigrationRecipe(config, loggerFactory); + + if (startSsh && !recipe.StartSshTunnel(force: true)) + { + Console.WriteLine("Failed to start SSH tunnel"); + return; + } + + var success = recipe.DiscoverAndAnalyzeTables(); + + if (startSsh) + { + recipe.StopSshTunnel(); + } + + if (!success) + { + Console.WriteLine("Discovery failed"); + } + } + + [Command("export", Description = "Export tables from source database to CSV files")] + public void Export( + [Option("include-tables", Description = "Comma-separated list of tables to include")] + string? includeTables = null, + [Option("exclude-tables", Description = "Comma-separated list of tables to exclude")] + string? excludeTables = null, + [Option("startssh", Description = "Start SSH tunnel before operation")] + bool startSsh = false + ) + { + var config = MigrationSettingsFactory.MigrationConfig; + var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + var recipe = new CsvMigrationRecipe(config, loggerFactory); + + TableFilter? tableFilter = null; + var includeList = TableFilter.ParseTableList(includeTables); + var excludeList = TableFilter.ParseTableList(excludeTables); + + if (includeList.Count > 0 || excludeList.Count > 0) + { + tableFilter = new TableFilter( + includeList.Count > 0 ? includeList : null, + excludeList.Count > 0 ? 
excludeList : null, + null, + loggerFactory.CreateLogger()); + } + + if (startSsh && !recipe.StartSshTunnel(force: true)) + { + Console.WriteLine("Failed to start SSH tunnel"); + return; + } + + var success = recipe.ExportAllTables(tableFilter); + + if (startSsh) + { + recipe.StopSshTunnel(); + } + + if (!success) + { + Console.WriteLine("Export failed"); + } + } + + [Command("import", Description = "Import CSV files to destination database")] + public void Import( + [Operand(Description = "Database type (postgres, mariadb, sqlite, sqlserver)")] + string database, + [Option("create-tables", Description = "Create tables if they don't exist")] + bool createTables = false, + [Option("clear-existing", Description = "Clear existing data before import")] + bool clearExisting = false, + [Option("verify", Description = "Verify import after completion")] + bool verify = false, + [Option("include-tables", Description = "Comma-separated list of tables to include")] + string? includeTables = null, + [Option("exclude-tables", Description = "Comma-separated list of tables to exclude")] + string? excludeTables = null, + [Option("batch-size", Description = "Number of rows per batch")] + int? batchSize = null + ) + { + var config = MigrationSettingsFactory.MigrationConfig; + var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + var recipe = new CsvMigrationRecipe(config, loggerFactory); + + TableFilter? tableFilter = null; + var includeList = TableFilter.ParseTableList(includeTables); + var excludeList = TableFilter.ParseTableList(excludeTables); + + if (includeList.Count > 0 || excludeList.Count > 0) + { + tableFilter = new TableFilter( + includeList.Count > 0 ? includeList : null, + excludeList.Count > 0 ? excludeList : null, + null, + loggerFactory.CreateLogger()); + } + + var success = recipe.ImportToDatabase(database, createTables, clearExisting, tableFilter, batchSize); + + if (verify && success) + { + Console.WriteLine("\nRunning verification..."); + var verifySuccess = recipe.VerifyImport(database, tableFilter); + if (!verifySuccess) + { + Console.WriteLine("Import succeeded but verification found issues"); + } + } + + if (!success) + { + Console.WriteLine("Import failed"); + } + } + + [Command("verify", Description = "Verify import by comparing CSV row counts with database row counts")] + public void Verify( + [Operand(Description = "Database type (postgres, mariadb, sqlite, sqlserver)")] + string database, + [Option("include-tables", Description = "Comma-separated list of tables to include")] + string? includeTables = null, + [Option("exclude-tables", Description = "Comma-separated list of tables to exclude")] + string? excludeTables = null + ) + { + var config = MigrationSettingsFactory.MigrationConfig; + var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + var recipe = new CsvMigrationRecipe(config, loggerFactory); + + TableFilter? tableFilter = null; + var includeList = TableFilter.ParseTableList(includeTables); + var excludeList = TableFilter.ParseTableList(excludeTables); + + if (includeList.Count > 0 || excludeList.Count > 0) + { + tableFilter = new TableFilter( + includeList.Count > 0 ? includeList : null, + excludeList.Count > 0 ? 
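+
+    // Verification is also exposed as a standalone command (below), so an
+    // earlier import run can be re-checked without re-copying any data.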
+    [Command("verify", Description = "Verify import by comparing CSV row counts with database row counts")]
+    public void Verify(
+        [Operand(Description = "Database type (postgres, mariadb, sqlite, sqlserver)")]
+        string database,
+        [Option("include-tables", Description = "Comma-separated list of tables to include")]
+        string? includeTables = null,
+        [Option("exclude-tables", Description = "Comma-separated list of tables to exclude")]
+        string? excludeTables = null
+    )
+    {
+        var config = MigrationSettingsFactory.MigrationConfig;
+        var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
+        var recipe = new CsvMigrationRecipe(config, loggerFactory);
+
+        TableFilter? tableFilter = null;
+        var includeList = TableFilter.ParseTableList(includeTables);
+        var excludeList = TableFilter.ParseTableList(excludeTables);
+
+        if (includeList.Count > 0 || excludeList.Count > 0)
+        {
+            tableFilter = new TableFilter(
+                includeList.Count > 0 ? includeList : null,
+                excludeList.Count > 0 ? excludeList : null,
+                null,
+                loggerFactory.CreateLogger<TableFilter>());
+        }
+
+        var success = recipe.VerifyImport(database, tableFilter);
+
+        if (!success)
+        {
+            Console.WriteLine("Verification failed");
+        }
+    }
+
+    [Command("test-connection", Description = "Test connection to a specific database")]
+    public void TestConnection(
+        [Operand(Description = "Database type (postgres, mariadb, sqlite, sqlserver)")]
+        string database
+    )
+    {
+        var config = MigrationSettingsFactory.MigrationConfig;
+        var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
+        var recipe = new CsvMigrationRecipe(config, loggerFactory);
+
+        var success = recipe.TestConnection(database);
+
+        if (!success)
+        {
+            Console.WriteLine($"Connection to {database} failed");
+        }
+    }
 }
diff --git a/util/Seeder/Migration/CsvHandler.cs b/util/Seeder/Migration/CsvHandler.cs
new file mode 100644
index 0000000000..89d531fe56
--- /dev/null
+++ b/util/Seeder/Migration/CsvHandler.cs
@@ -0,0 +1,375 @@
+using CsvHelper;
+using CsvHelper.Configuration;
+using Bit.Seeder.Migration.Models;
+using Microsoft.Extensions.Logging;
+using Newtonsoft.Json;
+using System.Globalization;
+using System.Text;
+
+namespace Bit.Seeder.Migration;
+
+public class CsvHandler(CsvSettings settings, ILogger logger)
+{
+    private readonly ILogger _logger = logger;
+    private readonly CsvSettings _settings = settings;
+    private readonly string _outputDir = settings.OutputDir;
+    private readonly string _delimiter = settings.Delimiter;
+    private readonly string _fallbackDelimiter = settings.FallbackDelimiter;
+    private readonly Encoding _encoding = new UTF8Encoding(false);
+
+    public string ExportTableToCsv(
+        string tableName,
+        List<string> columns,
+        List<object[]> data,
+        List<string>? specialColumns = null)
+    {
+        specialColumns ??= [];
+        var csvPath = Path.Combine(_outputDir, $"{tableName}.csv");
+
+        _logger.LogInformation("Exporting {TableName} to {CsvPath}", tableName, csvPath);
+        _logger.LogInformation("Special JSON columns: {Columns}", string.Join(", ", specialColumns));
+
+        try
+        {
+            // Ensure output directory exists
+            Directory.CreateDirectory(_outputDir);
+
+            // Test if we can write with primary delimiter
+            var delimiterToUse = TestDelimiterCompatibility(data, columns, specialColumns);
+
+            var config = new CsvConfiguration(CultureInfo.InvariantCulture)
+            {
+                Delimiter = delimiterToUse,
+                HasHeaderRecord = _settings.IncludeHeaders,
+                Encoding = _encoding,
+                ShouldQuote = _ => true // Always quote all fields (QUOTE_ALL)
+            };
+
+            using var writer = new StreamWriter(csvPath, false, _encoding);
+            using var csv = new CsvWriter(writer, config);
+
+            // Write headers if requested
+            if (_settings.IncludeHeaders)
+            {
+                foreach (var column in columns)
+                {
+                    csv.WriteField(column);
+                }
+                csv.NextRecord();
+            }
+
+            // Write data rows
+            var rowsWritten = 0;
+            foreach (var row in data)
+            {
+                var processedRow = ProcessRowForExport(row, columns, specialColumns);
+                foreach (var field in processedRow)
+                {
+                    csv.WriteField(field);
+                }
+                csv.NextRecord();
+                rowsWritten++;
+            }
+
+            _logger.LogInformation("Successfully exported {RowsWritten} rows to {CsvPath}", rowsWritten, csvPath);
+            return csvPath;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error exporting table {TableName}: {Message}", tableName, ex.Message);
+            throw;
+        }
+    }
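+
+    // Reads a CSV produced by ExportTableToCsv back into memory. The delimiter
+    // is re-detected from the file's first line, so exports written with the
+    // fallback delimiter round-trip without extra configuration.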
+    public (List<string> Columns, List<object[]> Data) ImportCsvToData(
+        string csvPath,
+        List<string>? specialColumns = null)
+    {
+        specialColumns ??= [];
+
+        if (!File.Exists(csvPath))
+        {
+            throw new FileNotFoundException($"CSV file not found: {csvPath}");
+        }
+
+        _logger.LogInformation("Reading data from {CsvPath}", csvPath);
+
+        try
+        {
+            // Detect delimiter
+            var delimiterUsed = DetectCsvDelimiter(csvPath);
+            _logger.LogDebug("Detected delimiter for {CsvPath}: '{Delimiter}' (ASCII: {Ascii})", csvPath, delimiterUsed, (int)delimiterUsed[0]);
+
+            var config = new CsvConfiguration(CultureInfo.InvariantCulture)
+            {
+                Delimiter = delimiterUsed,
+                HasHeaderRecord = _settings.IncludeHeaders,
+                Encoding = _encoding,
+                BadDataFound = null, // Ignore bad data
+                TrimOptions = TrimOptions.None // Don't trim anything
+            };
+
+            using var reader = new StreamReader(csvPath, _encoding);
+            using var csv = new CsvReader(reader, config);
+
+            var columns = new List<string>();
+            var dataRows = new List<object[]>();
+
+            // Read headers if present
+            if (_settings.IncludeHeaders)
+            {
+                csv.Read();
+                csv.ReadHeader();
+                var rawColumns = csv.HeaderRecord?.ToList() ?? [];
+                _logger.LogDebug("Raw columns from CSV: {Columns}", string.Join(", ", rawColumns));
+                // Remove surrounding quotes from column names if present
+                columns = rawColumns.Select(col => col.Trim('"')).ToList();
+                _logger.LogDebug("Cleaned columns: {Columns}", string.Join(", ", columns));
+            }
+
+            // Read data rows
+            while (csv.Read())
+            {
+                // Without headers, infer the column count from the first record
+                // so the field loop below still reads every value.
+                if (columns.Count == 0)
+                {
+                    columns = Enumerable.Range(0, csv.Parser.Count).Select(i => $"col_{i}").ToList();
+                }
+
+                var row = new List<object>();
+                for (int i = 0; i < columns.Count; i++)
+                {
+                    var field = csv.GetField(i) ?? string.Empty;
+                    row.Add(field);
+                }
+                var processedRow = ProcessRowForImport(row.ToArray(), columns, specialColumns);
+                dataRows.Add(processedRow);
+            }
+
+            _logger.LogInformation("Successfully read {RowCount} rows from {CsvPath}", dataRows.Count, csvPath);
+            return (columns, dataRows);
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error importing CSV {CsvPath}: {Message}", csvPath, ex.Message);
+            throw;
+        }
+    }
+
+    private string TestDelimiterCompatibility(
+        List<object[]> data,
+        List<string> columns,
+        List<string> specialColumns)
+    {
+        // Check a sample of rows for delimiter conflicts
+        var sampleSize = Math.Min(100, data.Count);
+        var specialColIndices = columns
+            .Select((col, idx) => new { col, idx })
+            .Where(x => specialColumns.Contains(x.col))
+            .Select(x => x.idx)
+            .ToList();
+
+        foreach (var row in data.Take(sampleSize))
+        {
+            foreach (var colIdx in specialColIndices)
+            {
+                if (colIdx < row.Length && row[colIdx] != null)
+                {
+                    var cellValue = row[colIdx]?.ToString() ?? string.Empty;
+                    // If primary delimiter appears in JSON data, use fallback
+                    if (cellValue.Contains(_delimiter) && !IsProperlyQuoted(cellValue))
+                    {
+                        _logger.LogInformation(
+                            "Primary delimiter '{Delimiter}' found in data, using fallback '{FallbackDelimiter}'", _delimiter, _fallbackDelimiter);
+                        return _fallbackDelimiter;
+                    }
+                }
+            }
+        }
+
+        return _delimiter;
+    }
+
+    private bool IsProperlyQuoted(string value)
+    {
+        return value.StartsWith("\"") && value.EndsWith("\"");
+    }
+
+    private string DetectCsvDelimiter(string csvPath)
+    {
+        using var reader = new StreamReader(csvPath, _encoding);
+        // Read just the first line (header) for delimiter detection
+        var firstLine = reader.ReadLine();
+        if (string.IsNullOrEmpty(firstLine))
+            return ",";
+
+        // Count delimiters outside of quoted fields
+        var commaCount = CountDelimitersOutsideQuotes(firstLine, ',');
+        var pipeCount = CountDelimitersOutsideQuotes(firstLine, '|');
+        var tabCount = CountDelimitersOutsideQuotes(firstLine, '\t');
+
+        _logger.LogDebug("Delimiter counts - comma: {CommaCount}, pipe: {PipeCount}, tab: {TabCount}", commaCount, pipeCount, tabCount);
+
+        if (pipeCount > commaCount && pipeCount > tabCount)
+            return "|";
+        if (tabCount > commaCount && tabCount > pipeCount)
+            return "\t";
+
+        return ",";
+    }
+
+    private int CountDelimitersOutsideQuotes(string line, char delimiter)
+    {
+        int count = 0;
+        bool inQuotes = false;
+
+        for (int i = 0; i < line.Length; i++)
+        {
+            if (line[i] == '"')
+            {
+                // Handle escaped quotes (double quotes)
+                if (i + 1 < line.Length && line[i + 1] == '"')
+                {
+                    i++; // Skip the next quote
+                }
+                else
+                {
+                    inQuotes = !inQuotes;
+                }
+            }
+            else if (line[i] == delimiter && !inQuotes)
+            {
+                count++;
+            }
+        }
+
+        return count;
+    }
+
+    private object[] ProcessRowForExport(
+        object[] row,
+        List<string> columns,
+        List<string> specialColumns)
+    {
+        var processedRow = new object[row.Length];
+
+        for (int i = 0; i < row.Length; i++)
+        {
+            var colName = i < columns.Count ? columns[i] : $"col_{i}";
+
+            if (row[i] == null)
+            {
+                processedRow[i] = string.Empty;
+            }
+            else if (specialColumns.Contains(colName))
+            {
+                // Handle JSON/encrypted data
+                processedRow[i] = PrepareJsonForCsv(row[i]);
+            }
+            else if (row[i] is DateTime dt)
+            {
+                // Format DateTime with full precision (microseconds)
+                // Format: yyyy-MM-dd HH:mm:ss.ffffff to match Python output
+                processedRow[i] = dt.ToString("yyyy-MM-dd HH:mm:ss.ffffff");
+            }
+            else
+            {
+                // Handle regular data
+                processedRow[i] = row[i].ToString() ?? string.Empty;
+            }
+        }
+
+        return processedRow;
+    }
+
+    private object[] ProcessRowForImport(
+        object[] row,
+        List<string> columns,
+        List<string> specialColumns)
+    {
+        var processedRow = new object[row.Length];
+
+        for (int i = 0; i < row.Length; i++)
+        {
+            var colName = i < columns.Count ? columns[i] : $"col_{i}";
+            var value = row[i]?.ToString() ?? string.Empty;
+
+            if (string.IsNullOrEmpty(value))
+            {
+                processedRow[i] = null!;
+            }
+            else if (specialColumns.Contains(colName))
+            {
+                // Handle JSON/encrypted data
+                processedRow[i] = RestoreJsonFromCsv(value) ?? value;
+            }
+            else
+            {
+                // Handle regular data
+                processedRow[i] = value;
+            }
+        }
+
+        return processedRow;
+    }
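+
+    // JSON-ish values are written verbatim; the quote-everything CsvWriter
+    // configuration in ExportTableToCsv already protects embedded delimiters
+    // and quotes, so no extra escaping layer is applied here.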
+    private string PrepareJsonForCsv(object jsonData)
+    {
+        if (jsonData == null)
+            return string.Empty;
+
+        var jsonStr = jsonData.ToString() ?? string.Empty;
+
+        // Validate if it's valid JSON (for logging purposes)
+        try
+        {
+            JsonConvert.DeserializeObject(jsonStr);
+            _logger.LogDebug("Valid JSON data prepared for CSV export");
+        }
+        catch (JsonException)
+        {
+            _logger.LogDebug("Non-JSON string data prepared for CSV export");
+        }
+
+        // Let CSV writer handle the escaping
+        return jsonStr;
+    }
+
+    private string? RestoreJsonFromCsv(string csvData)
+    {
+        if (string.IsNullOrEmpty(csvData))
+            return null;
+
+        // Return as-is - the CSV reader should have handled unescaping
+        return csvData;
+    }
+
+    public bool ValidateExport(int originalCount, string csvPath)
+    {
+        try
+        {
+            using var reader = new StreamReader(csvPath, _encoding);
+            var rowCount = 0L;
+            while (reader.ReadLine() != null)
+            {
+                rowCount++;
+            }
+
+            // Subtract header row if present (guard against an empty file)
+            if (_settings.IncludeHeaders && rowCount > 0)
+            {
+                rowCount--;
+            }
+
+            if ((int)rowCount == originalCount)
+            {
+                _logger.LogInformation("Export validation passed: {RowCount} rows", rowCount);
+                return true;
+            }
+            else
+            {
+                _logger.LogError("Export validation failed: expected {Expected}, got {Actual}", originalCount, rowCount);
+                return false;
+            }
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error validating export: {Message}", ex.Message);
+            return false;
+        }
+    }
+}
diff --git a/util/Seeder/Migration/Databases/MariaDbImporter.cs b/util/Seeder/Migration/Databases/MariaDbImporter.cs
new file mode 100644
index 0000000000..d7baaf7e9c
--- /dev/null
+++ b/util/Seeder/Migration/Databases/MariaDbImporter.cs
@@ -0,0 +1,439 @@
+using MySqlConnector;
+using Bit.Seeder.Migration.Models;
+using Bit.Seeder.Migration.Utils;
+using Microsoft.Extensions.Logging;
+
+namespace Bit.Seeder.Migration.Databases;
+
+public class MariaDbImporter(DatabaseConfig config, ILogger logger) : IDisposable
+{
+    private readonly ILogger _logger = logger;
+    private readonly string _host = config.Host;
+    private readonly int _port = config.Port > 0 ? config.Port : 3306;
+    private readonly string _database = config.Database;
+    private readonly string _username = config.Username;
+    private readonly string _password = config.Password;
+    private MySqlConnection? _connection;
+
+    public bool Connect()
+    {
+        try
+        {
+            var connectionString = $"Server={_host};Port={_port};Database={_database};" +
+                                   $"Uid={_username};Pwd={_password};" +
+                                   $"ConnectionTimeout=30;CharSet=utf8mb4;AllowLoadLocalInfile=true;MaxPoolSize=100;";
+
+            _connection = new MySqlConnection(connectionString);
+            _connection.Open();
+
+            _logger.LogInformation("Connected to MariaDB: {Host}:{Port}/{Database}", _host, _port, _database);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Failed to connect to MariaDB: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Disconnect()
+    {
+        if (_connection != null)
+        {
+            _connection.Close();
+            _connection.Dispose();
+            _connection = null;
+            _logger.LogInformation("Disconnected from MariaDB");
+        }
+    }
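+
+    // Creates the destination table from the exported SQL Server schema,
+    // mapping each source column type to its closest MariaDB equivalent via
+    // ConvertSqlServerTypeToMariaDB below.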
+    public bool CreateTableFromSchema(
+        string tableName,
+        List<string> columns,
+        Dictionary<string, string> columnTypes,
+        List<string>? specialColumns = null)
+    {
+        specialColumns ??= [];
+
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var mariaColumns = new List<string>();
+            foreach (var colName in columns)
+            {
+                var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
+                var mariaType = ConvertSqlServerTypeToMariaDB(sqlServerType, specialColumns.Contains(colName));
+                mariaColumns.Add($"`{colName}` {mariaType}");
+            }
+
+            var createSql = $@"
+                CREATE TABLE IF NOT EXISTS `{tableName}` (
+                    {string.Join(",\n    ", mariaColumns)}
+                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci";
+
+            _logger.LogInformation("Creating table {TableName} in MariaDB", tableName);
+            _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);
+
+            using var command = new MySqlCommand(createSql, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Successfully created table {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public List<string> GetTableColumns(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT column_name
+                FROM information_schema.columns
+                WHERE table_name = @tableName AND table_schema = @database
+                ORDER BY ordinal_position";
+
+            using var command = new MySqlCommand(query, _connection);
+            command.Parameters.AddWithValue("@tableName", tableName);
+            command.Parameters.AddWithValue("@database", _database);
+
+            var columns = new List<string>();
+            using var reader = command.ExecuteReader();
+            while (reader.Read())
+            {
+                columns.Add(reader.GetString(0));
+            }
+
+            return columns;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
+            return [];
+        }
+    }
+
+    public bool ImportData(
+        string tableName,
+        List<string> columns,
+        List<object[]> data,
+        int batchSize = 1000)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Filter columns
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumns.Contains(columns[i]))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(columns[i]);
+                }
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                return false;
+            }
+
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
+            ).ToList();
+
+            _logger.LogInformation("Importing {Count} rows into {TableName}", filteredData.Count, tableName);
+
+            // Build INSERT statement
+            var quotedColumns = validColumns.Select(col => $"`{col}`").ToList();
+            var placeholders = string.Join(", ", Enumerable.Range(0, validColumns.Count).Select(i => $"@p{i}"));
+            var insertSql = $"INSERT INTO `{tableName}` ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})";
+
+            var totalImported = 0;
+            for (int i = 0; i < filteredData.Count; i += batchSize)
+            {
+                var batch = filteredData.Skip(i).Take(batchSize).ToList();
+
+                using var transaction = _connection.BeginTransaction();
+                try
+                {
+                    foreach (var row in batch)
+                    {
+                        using var command = new MySqlCommand(insertSql, _connection, transaction);
+
+                        var preparedRow = PrepareRowForInsert(row, validColumns);
+                        for (int p = 0; p < preparedRow.Length; p++)
+                        {
+                            var value = preparedRow[p] ?? DBNull.Value;
+
+                            // For string values, explicitly set parameter type and size to avoid truncation
+                            if (value is string strValue)
+                            {
+                                var param = new MySqlParameter
+                                {
+                                    ParameterName = $"@p{p}",
+                                    MySqlDbType = MySqlDbType.LongText,
+                                    Value = strValue,
+                                    Size = strValue.Length
+                                };
+                                command.Parameters.Add(param);
+                            }
+                            else
+                            {
+                                command.Parameters.AddWithValue($"@p{p}", value);
+                            }
+                        }
+
+                        command.ExecuteNonQuery();
+                    }
+
+                    transaction.Commit();
+                    totalImported += batch.Count;
+
+                    if (filteredData.Count > 1000)
+                    {
+                        _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
+                    }
+                }
+                catch
+                {
+                    transaction.Rollback();
+                    throw;
+                }
+            }
+
+            _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error importing data into {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public bool TableExists(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT COUNT(*)
+                FROM information_schema.tables
+                WHERE table_schema = @database AND table_name = @tableName";
+
+            using var command = new MySqlCommand(query, _connection);
+            command.Parameters.AddWithValue("@database", _database);
+            command.Parameters.AddWithValue("@tableName", tableName);
+
+            var count = Convert.ToInt32(command.ExecuteScalar());
+            return count > 0;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error checking if table {TableName} exists: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public int GetTableRowCount(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = $"SELECT COUNT(*) FROM `{tableName}`";
+            using var command = new MySqlCommand(query, _connection);
+
+            return Convert.ToInt32(command.ExecuteScalar());
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting row count for {TableName}: {Message}", tableName, ex.Message);
+            return 0;
+        }
+    }
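+
+    // Schema maintenance helpers: DropTable plus the FOREIGN_KEY_CHECKS
+    // toggles below let a fresh import recreate tables and bulk-load batches
+    // without tripping referential constraints mid-run.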
+    public bool DropTable(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = $"DROP TABLE IF EXISTS `{tableName}`";
+            using var command = new MySqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Dropped table {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error dropping table {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public bool DisableForeignKeys()
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            _logger.LogInformation("Disabling foreign key constraints");
+            var query = "SET FOREIGN_KEY_CHECKS = 0";
+            using var command = new MySqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Foreign key constraints disabled");
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error disabling foreign key constraints: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public bool EnableForeignKeys()
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            _logger.LogInformation("Re-enabling foreign key constraints");
+            var query = "SET FOREIGN_KEY_CHECKS = 1";
+            using var command = new MySqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Foreign key constraints re-enabled");
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error re-enabling foreign key constraints: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    private string ConvertSqlServerTypeToMariaDB(string sqlServerType, bool isJsonColumn)
+    {
+        var baseType = sqlServerType.Replace(" NULL", "").Replace(" NOT NULL", "").Trim();
+        var isNullable = !sqlServerType.Contains("NOT NULL");
+
+        if (isJsonColumn)
+            return "LONGTEXT" + (isNullable ? "" : " NOT NULL");
+
+        var mariaType = baseType.ToUpper() switch
+        {
+            var t when t.StartsWith("VARCHAR") => t.Contains("MAX") ? "LONGTEXT" : t, // VARCHAR(n) carries over unchanged
+            var t when t.StartsWith("NVARCHAR") => "LONGTEXT",
+            "INT" or "INTEGER" => "INT",
+            "BIGINT" => "BIGINT",
+            "SMALLINT" => "SMALLINT",
+            "TINYINT" => "TINYINT",
+            "BIT" => "BOOLEAN",
+            var t when t.StartsWith("DECIMAL") => t, // DECIMAL(p,s) carries over unchanged
+            "FLOAT" => "DOUBLE",
+            "REAL" => "FLOAT",
+            "DATETIME" or "DATETIME2" or "SMALLDATETIME" => "DATETIME",
+            "DATE" => "DATE",
+            "TIME" => "TIME",
+            "UNIQUEIDENTIFIER" => "CHAR(36)",
+            var t when t.StartsWith("VARBINARY") => "LONGBLOB",
+            "XML" => "LONGTEXT",
+            _ => "LONGTEXT"
+        };
+
+        return mariaType + (isNullable ? "" : " NOT NULL");
+    }
+
+    private object[] PrepareRowForInsert(object?[] row, List<string> columns)
+    {
+        return row.Select(value =>
+        {
+            if (value == null || value == DBNull.Value)
+                return DBNull.Value;
+
+            if (value is string strValue)
+            {
+                // Only convert truly empty strings to DBNull, not whitespace
+                // This preserves JSON strings and other data that might have whitespace
+                if (strValue.Length == 0)
+                    return DBNull.Value;
+
+                if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase))
+                    return true;
+                if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase))
+                    return false;
+
+                // Handle datetime with timezone
+                if ((strValue.Contains('+') || strValue.EndsWith('Z')) &&
+                    DateTimeHelper.IsLikelyIsoDateTime(strValue))
+                {
+                    return DateTimeHelper.RemoveTimezone(strValue) ?? strValue;
+                }
+            }
+
+            return value;
+        }).ToArray();
+    }
+
+    public bool TestConnection()
+    {
+        try
+        {
+            if (Connect())
+            {
+                using var command = new MySqlCommand("SELECT 1", _connection);
+                var result = command.ExecuteScalar();
+                Disconnect();
+                return result != null && Convert.ToInt32(result) == 1;
+            }
+            return false;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("MariaDB connection test failed: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Dispose()
+    {
+        Disconnect();
+    }
+}
diff --git a/util/Seeder/Migration/Databases/PostgresImporter.cs b/util/Seeder/Migration/Databases/PostgresImporter.cs
new file mode 100644
index 0000000000..06c0e88fdd
--- /dev/null
+++ b/util/Seeder/Migration/Databases/PostgresImporter.cs
@@ -0,0 +1,552 @@
+using Npgsql;
+using Bit.Seeder.Migration.Models;
+using Microsoft.Extensions.Logging;
+
+namespace Bit.Seeder.Migration.Databases;
+
+public class PostgresImporter(DatabaseConfig config, ILogger logger) : IDisposable
+{
+    private readonly ILogger _logger = logger;
+    private readonly string _host = config.Host;
+    private readonly int _port = config.Port > 0 ? config.Port : 5432;
+    private readonly string _database = config.Database;
+    private readonly string _username = config.Username;
+    private readonly string _password = config.Password;
+    private NpgsqlConnection? _connection;
+
+    public bool Connect()
+    {
+        try
+        {
+            var connectionString = $"Host={_host};Port={_port};Database={_database};" +
+                                   $"Username={_username};Password={_password};" +
+                                   $"Timeout=30;CommandTimeout=30;";
+
+            _connection = new NpgsqlConnection(connectionString);
+            _connection.Open();
+
+            _logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Failed to connect to PostgreSQL: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Disconnect()
+    {
+        if (_connection != null)
+        {
+            _connection.Close();
+            _connection.Dispose();
+            _connection = null;
+            _logger.LogInformation("Disconnected from PostgreSQL");
+        }
+    }
+
+    public bool CreateTableFromSchema(
+        string tableName,
+        List<string> columns,
+        Dictionary<string, string> columnTypes,
+        List<string>? specialColumns = null)
+    {
+        specialColumns ??= [];
+
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            // Convert SQL Server types to PostgreSQL types
+            var pgColumns = new List<string>();
+            foreach (var colName in columns)
+            {
+                var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)");
+                var pgType = ConvertSqlServerTypeToPostgreSQL(sqlServerType, specialColumns.Contains(colName));
+                pgColumns.Add($"\"{colName}\" {pgType}");
+            }
+
+            // Create tables with quoted identifiers to preserve case
+            var createSql = $@"
+                CREATE TABLE IF NOT EXISTS ""{tableName}"" (
+                    {string.Join(",\n    ", pgColumns)}
+                )";
+
+            _logger.LogInformation("Creating table {TableName} in PostgreSQL", tableName);
+            _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);
+
+            using var command = new NpgsqlCommand(createSql, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Successfully created table {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
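+
+    // PostgreSQL folds unquoted identifiers to lowercase, so a table created
+    // outside this tool may not match the CSV file's casing. Resolve the
+    // stored name case-insensitively before building any quoted queries.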
+    private string? GetActualTableName(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT table_name
+                FROM information_schema.tables
+                WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
+                LIMIT 1";
+
+            using var command = new NpgsqlCommand(query, _connection);
+            command.Parameters.AddWithValue("tableName", tableName);
+
+            using var reader = command.ExecuteReader();
+            if (reader.Read())
+            {
+                return reader.GetString(0);
+            }
+
+            return null;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting actual table name for {TableName}: {Message}", tableName, ex.Message);
+            return null;
+        }
+    }
+
+    public List<string> GetTableColumns(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT column_name
+                FROM information_schema.columns
+                WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
+                ORDER BY ordinal_position";
+
+            using var command = new NpgsqlCommand(query, _connection);
+            command.Parameters.AddWithValue("tableName", tableName);
+
+            var columns = new List<string>();
+            using var reader = command.ExecuteReader();
+            while (reader.Read())
+            {
+                columns.Add(reader.GetString(0));
+            }
+
+            return columns;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
+            return [];
+        }
+    }
+
+    private Dictionary<string, string> GetTableColumnTypes(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var columnTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
+            var query = @"
+                SELECT column_name, data_type
+                FROM information_schema.columns
+                WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'";
+
+            using var command = new NpgsqlCommand(query, _connection);
+            command.Parameters.AddWithValue("tableName", tableName);
+
+            using var reader = command.ExecuteReader();
+            while (reader.Read())
+            {
+                columnTypes[reader.GetString(0)] = reader.GetString(1);
+            }
+
+            return columnTypes;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting column types for table {TableName}: {Message}", tableName, ex.Message);
+            return new Dictionary<string, string>();
+        }
+    }
+
+    public bool ImportData(
+        string tableName,
+        List<string> columns,
+        List<object[]> data,
+        int batchSize = 1000)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            // Get the actual table name with correct casing
+            var actualTableName = GetActualTableName(tableName);
+            if (actualTableName == null)
+            {
+                _logger.LogError("Table {TableName} not found in database", tableName);
+                return false;
+            }
+
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Get column types from the database
+            var columnTypes = GetTableColumnTypes(tableName);
+
+            // Filter columns - use case-insensitive comparison
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+            var validColumnTypes = new List<string>();
+
+            // Create a case-insensitive lookup of actual columns
+            var actualColumnsLookup = actualColumns.ToDictionary(c => c, c => c, StringComparer.OrdinalIgnoreCase);
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumnsLookup.TryGetValue(columns[i], out var actualColumnName))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(actualColumnName); // Use the actual column name from DB
+                    validColumnTypes.Add(columnTypes.GetValueOrDefault(actualColumnName, "text"));
+                }
+                else
+                {
+                    _logger.LogDebug("Column '{Column}' from CSV not found in table {TableName}", columns[i], tableName);
+                }
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                _logger.LogError("CSV columns: {Columns}", string.Join(", ", columns));
+                _logger.LogError("Table columns: {Columns}", string.Join(", ", actualColumns));
+                return false;
+            }
+
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
+            ).ToList();
+
+            _logger.LogInformation("Importing {Count} rows into {TableName}", filteredData.Count, tableName);
+
+            // Build INSERT statement with explicit type casts for all types
+            var quotedColumns = validColumns.Select(col => $"\"{col}\"").ToList();
+            var placeholders = validColumns.Select((col, idx) =>
+            {
+                var paramNum = idx + 1;
+                var colType = validColumnTypes[idx];
+                // Cast to appropriate type if needed - PostgreSQL requires explicit casts for text to other types
+                return colType switch
+                {
+                    // UUID types
+                    "uuid" => $"${paramNum}::uuid",
+
+                    // Timestamp types
+                    "timestamp without time zone" => $"${paramNum}::timestamp",
+                    "timestamp with time zone" => $"${paramNum}::timestamptz",
+                    "date" => $"${paramNum}::date",
+                    "time without time zone" => $"${paramNum}::time",
+                    "time with time zone" => $"${paramNum}::timetz",
+
+                    // Integer types
+                    "smallint" => $"${paramNum}::smallint",
+                    "integer" => $"${paramNum}::integer",
+                    "bigint" => $"${paramNum}::bigint",
+
+                    // Numeric types
+                    "numeric" => $"${paramNum}::numeric",
+                    "decimal" => $"${paramNum}::decimal",
+                    "real" => $"${paramNum}::real",
+                    "double precision" => $"${paramNum}::double precision",
+
+                    // Boolean type
+                    "boolean" => $"${paramNum}::boolean",
+
+                    // Default - no cast needed for text types
+                    _ => $"${paramNum}"
+                };
+            });
+            var insertSql = $"INSERT INTO \"{actualTableName}\" ({string.Join(", ", quotedColumns)}) VALUES ({string.Join(", ", placeholders)})";
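+
+            // Npgsql binds the unnamed parameters added below positionally,
+            // lining up with the $1..$n placeholders (and casts) built above.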
+            var totalImported = 0;
+            for (int i = 0; i < filteredData.Count; i += batchSize)
+            {
+                var batch = filteredData.Skip(i).Take(batchSize).ToList();
+
+                using var transaction = _connection.BeginTransaction();
+                try
+                {
+                    foreach (var row in batch)
+                    {
+                        using var command = new NpgsqlCommand(insertSql, _connection, transaction);
+
+                        var preparedRow = PrepareRowForInsert(row, validColumns);
+                        for (int p = 0; p < preparedRow.Length; p++)
+                        {
+                            command.Parameters.AddWithValue(preparedRow[p] ?? DBNull.Value);
+                        }
+
+                        command.ExecuteNonQuery();
+                    }
+
+                    transaction.Commit();
+                    totalImported += batch.Count;
+
+                    if (filteredData.Count > 1000)
+                    {
+                        _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
+                    }
+                }
+                catch
+                {
+                    transaction.Rollback();
+                    throw;
+                }
+            }
+
+            _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error importing data into {TableName}: {Message}", tableName, ex.Message);
+            _logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
+            if (ex.InnerException != null)
+            {
+                _logger.LogError("Inner exception: {Message}", ex.InnerException.Message);
+            }
+            return false;
+        }
+    }
+
+    public bool TableExists(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT EXISTS (
+                    SELECT 1 FROM information_schema.tables
+                    WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
+                )";
+
+            using var command = new NpgsqlCommand(query, _connection);
+            command.Parameters.AddWithValue("tableName", tableName);
+
+            return (bool)command.ExecuteScalar()!;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error checking if table {TableName} exists: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public int GetTableRowCount(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var actualTableName = GetActualTableName(tableName);
+            if (actualTableName == null)
+            {
+                _logger.LogError("Table {TableName} not found in database", tableName);
+                return 0;
+            }
+
+            var query = $"SELECT COUNT(*) FROM \"{actualTableName}\"";
+            using var command = new NpgsqlCommand(query, _connection);
+
+            return Convert.ToInt32(command.ExecuteScalar());
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting row count for {TableName}: {Message}", tableName, ex.Message);
+            return 0;
+        }
+    }
+
+    public bool DropTable(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var actualTableName = GetActualTableName(tableName);
+            if (actualTableName == null)
+            {
+                _logger.LogWarning("Table {TableName} not found, skipping drop", tableName);
+                return true;
+            }
+
+            var query = $"DROP TABLE IF EXISTS \"{actualTableName}\" CASCADE";
+            using var command = new NpgsqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Dropped table {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error dropping table {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public bool DisableForeignKeys()
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            _logger.LogInformation("Disabling foreign key constraints");
+            var query = "SET session_replication_role = replica;";
+            using var command = new NpgsqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Foreign key constraints deferred");
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error disabling foreign key constraints: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public bool EnableForeignKeys()
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            _logger.LogInformation("Re-enabling foreign key constraints");
+            var query = "SET session_replication_role = DEFAULT;";
+            using var command = new NpgsqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Foreign key constraints re-enabled");
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error re-enabling foreign key constraints: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    private string ConvertSqlServerTypeToPostgreSQL(string sqlServerType, bool isJsonColumn)
+    {
+        var baseType = sqlServerType.Replace(" NULL", "").Replace(" NOT NULL", "").Trim();
+        var isNullable = !sqlServerType.Contains("NOT NULL");
+
+        if (isJsonColumn)
+            return "TEXT" + (isNullable ? "" : " NOT NULL");
+
+        var pgType = baseType.ToUpper() switch
+        {
+            var t when t.StartsWith("VARCHAR") => t.Contains("MAX") ? "TEXT" : t, // VARCHAR(n) carries over unchanged
+            var t when t.StartsWith("NVARCHAR") => "TEXT",
+            "INT" or "INTEGER" => "INTEGER",
+            "BIGINT" => "BIGINT",
+            "SMALLINT" => "SMALLINT",
+            "TINYINT" => "SMALLINT",
+            "BIT" => "BOOLEAN",
+            var t when t.StartsWith("DECIMAL") => t, // DECIMAL(p,s) carries over unchanged
+            "FLOAT" => "DOUBLE PRECISION",
+            "REAL" => "REAL",
+            "DATETIME" or "DATETIME2" or "SMALLDATETIME" => "TIMESTAMP",
+            "DATE" => "DATE",
+            "TIME" => "TIME",
+            "DATETIMEOFFSET" => "TIMESTAMPTZ",
+            "UNIQUEIDENTIFIER" => "UUID",
+            var t when t.StartsWith("VARBINARY") => "BYTEA",
+            "XML" => "XML",
+            _ => "TEXT"
+        };
+
+        return pgType + (isNullable ? "" : " NOT NULL");
+    }
+
+    private object[] PrepareRowForInsert(object?[] row, List<string> columns)
+    {
+        return row.Select(value =>
+        {
+            if (value == null || value == DBNull.Value)
+                return DBNull.Value;
+
+            if (value is string strValue)
+            {
+                // Only convert truly empty strings to DBNull, not whitespace
+                // This preserves JSON strings and other data that might have whitespace
+                if (strValue.Length == 0)
+                    return DBNull.Value;
+
+                if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase))
+                    return true;
+                if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase))
+                    return false;
+            }
+
+            return value;
+        }).ToArray();
+    }
+
+    public bool TestConnection()
+    {
+        try
+        {
+            if (Connect())
+            {
+                using var command = new NpgsqlCommand("SELECT 1", _connection);
+                var result = command.ExecuteScalar();
+                Disconnect();
+                return result != null && (int)result == 1;
+            }
+            return false;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("PostgreSQL connection test failed: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Dispose()
+    {
+        Disconnect();
+    }
+}
diff --git a/util/Seeder/Migration/Databases/SqlServerExporter.cs b/util/Seeder/Migration/Databases/SqlServerExporter.cs
new file mode 100644
index 0000000000..84ea35a224
--- /dev/null
+++ b/util/Seeder/Migration/Databases/SqlServerExporter.cs
@@ -0,0 +1,380 @@
+using Microsoft.Data.SqlClient;
+using Bit.Seeder.Migration.Models;
+using Microsoft.Extensions.Logging;
+
+namespace Bit.Seeder.Migration.Databases;
+
+public class SqlServerExporter(DatabaseConfig config, ILogger logger) : IDisposable
+{
+    private readonly ILogger _logger = logger;
+    private readonly string _host = config.Host;
+    private readonly int _port = config.Port;
+    private readonly string _database = config.Database;
+    private readonly string _username = config.Username;
+    private readonly string _password = config.Password;
+    private SqlConnection? _connection;
+    public bool Connect()
+    {
+        try
+        {
+            var connectionString = $"Server={_host},{_port};Database={_database};" +
+                                   $"User Id={_username};Password={_password};" +
+                                   $"TrustServerCertificate=True;Connection Timeout=30;";
+
+            _connection = new SqlConnection(connectionString);
+            _connection.Open();
+
+            _logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Failed to connect to SQL Server: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Disconnect()
+    {
+        if (_connection != null)
+        {
+            _connection.Close();
+            _connection.Dispose();
+            _connection = null;
+            _logger.LogInformation("Disconnected from SQL Server");
+        }
+    }
+
+    public List<string> DiscoverTables(bool excludeSystemTables = true)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT TABLE_NAME
+                FROM INFORMATION_SCHEMA.TABLES
+                WHERE TABLE_TYPE = 'BASE TABLE'";
+
+            if (excludeSystemTables)
+            {
+                query += @"
+                    AND TABLE_SCHEMA = 'dbo'
+                    AND TABLE_NAME NOT IN ('sysdiagrams', '__EFMigrationsHistory')";
+            }
+
+            query += " ORDER BY TABLE_NAME";
+
+            using var command = new SqlCommand(query, _connection);
+            using var reader = command.ExecuteReader();
+
+            var tables = new List<string>();
+            while (reader.Read())
+            {
+                tables.Add(reader.GetString(0));
+            }
+
+            _logger.LogInformation("Discovered {Count} tables: {Tables}", tables.Count, string.Join(", ", tables));
+            return tables;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error discovering tables: {Message}", ex.Message);
+            throw;
+        }
+    }
+
+    public TableInfo GetTableInfo(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            // Get column information
+            var columnQuery = @"
+                SELECT
+                    COLUMN_NAME,
+                    DATA_TYPE,
+                    IS_NULLABLE,
+                    CHARACTER_MAXIMUM_LENGTH,
+                    NUMERIC_PRECISION,
+                    NUMERIC_SCALE
+                FROM INFORMATION_SCHEMA.COLUMNS
+                WHERE TABLE_NAME = @TableName
+                ORDER BY ORDINAL_POSITION";
+
+            var columns = new List<string>();
+            var columnTypes = new Dictionary<string, string>();
+
+            using (var command = new SqlCommand(columnQuery, _connection))
+            {
+                command.Parameters.AddWithValue("@TableName", tableName);
+                using var reader = command.ExecuteReader();
+
+                while (reader.Read())
+                {
+                    var colName = reader.GetString(0);
+                    var dataType = reader.GetString(1);
+                    var isNullable = reader.GetString(2);
+                    var maxLength = reader.IsDBNull(3) ? (int?)null : reader.GetInt32(3);
+                    var precision = reader.IsDBNull(4) ? (byte?)null : reader.GetByte(4);
+                    var scale = reader.IsDBNull(5) ? (int?)null : reader.GetInt32(5);
+
+                    columns.Add(colName);
+
+                    // Build type description
+                    var typeDesc = dataType.ToUpper();
+                    if (maxLength.HasValue && dataType.ToLower() is "varchar" or "nvarchar" or "char" or "nchar")
+                    {
+                        typeDesc += $"({maxLength})";
+                    }
+                    else if (precision.HasValue && dataType.ToLower() is "decimal" or "numeric")
+                    {
+                        typeDesc += $"({precision},{scale})";
+                    }
+
+                    typeDesc += isNullable == "YES" ? " NULL" : " NOT NULL";
+                    columnTypes[colName] = typeDesc;
+                }
+            }
+
+            if (columns.Count == 0)
+            {
+                throw new InvalidOperationException($"Table '{tableName}' not found");
+            }
+
+            // Get row count
+            var countQuery = $"SELECT COUNT(*) FROM [{tableName}]";
+            int rowCount;
+
+            using (var command = new SqlCommand(countQuery, _connection))
+            {
+                rowCount = (int)command.ExecuteScalar()!;
+            }
+
+            _logger.LogInformation("Table {TableName}: {ColumnCount} columns, {RowCount} rows", tableName, columns.Count, rowCount);
+
+            return new TableInfo
+            {
+                Name = tableName,
+                Columns = columns,
+                ColumnTypes = columnTypes,
+                RowCount = rowCount
+            };
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting table info for {TableName}: {Message}", tableName, ex.Message);
+            throw;
+        }
+    }
+
+    public (List<string> Columns, List<object[]> Data) ExportTableData(string tableName, int batchSize = 10000)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            // Get table info first
+            var tableInfo = GetTableInfo(tableName);
+
+            // Build column list with proper quoting
+            var quotedColumns = tableInfo.Columns.Select(col => $"[{col}]").ToList();
+            var columnList = string.Join(", ", quotedColumns);
+
+            // Execute query
+            var query = $"SELECT {columnList} FROM [{tableName}]";
+            _logger.LogInformation("Executing export query for {TableName}", tableName);
+
+            using var command = new SqlCommand(query, _connection);
+            command.CommandTimeout = 300; // 5 minutes
+
+            using var reader = command.ExecuteReader();
+
+            // Fetch data in batches
+            var allData = new List<object[]>();
+            while (reader.Read())
+            {
+                var row = new object[tableInfo.Columns.Count];
+                reader.GetValues(row);
+                allData.Add(row);
+
+                if (allData.Count % batchSize == 0)
+                {
+                    _logger.LogDebug("Fetched {Count} rows from {TableName}", allData.Count, tableName);
+                }
+            }
+
+            // Convert GUID values to uppercase to ensure compatibility with Bitwarden
+            var processedData = ConvertGuidsToUppercase(allData, tableInfo);
+
+            _logger.LogInformation("Exported {Count} rows from {TableName}", processedData.Count, tableName);
+            return (tableInfo.Columns, processedData);
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error exporting data from {TableName}: {Message}", tableName, ex.Message);
+            throw;
+        }
+    }
+
+    public List<string> IdentifyJsonColumns(string tableName, int sampleSize = 100)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var tableInfo = GetTableInfo(tableName);
+            var jsonColumns = new List<string>();
+
+            // Only check varchar/text columns
+            var textColumns = tableInfo.ColumnTypes
+                .Where(kv => kv.Value.ToLower().Contains("varchar") ||
+                             kv.Value.ToLower().Contains("text") ||
+                             kv.Value.ToLower().Contains("nvarchar"))
+                .Select(kv => kv.Key)
+                .ToList();
+
+            if (textColumns.Count == 0)
+                return jsonColumns;
+
+            // Sample data from text columns
+            var quotedColumns = textColumns.Select(col => $"[{col}]").ToList();
+            var columnList = string.Join(", ", quotedColumns);
+
+            var whereClause = string.Join(" OR ", textColumns.Select(col => $"[{col}] IS NOT NULL"));
+
+            var query = $@"
+                SELECT TOP {sampleSize} {columnList}
+                FROM [{tableName}]
+                WHERE {whereClause}";
+
+            using var command = new SqlCommand(query, _connection);
+            using var reader = command.ExecuteReader();
+
+            var sampleData = new List<object[]>();
+            while (reader.Read())
+            {
+                var row = new object[textColumns.Count];
+                reader.GetValues(row);
+                sampleData.Add(row);
+            }
+
+            // Analyze each column
+            for (int i = 0; i < textColumns.Count; i++)
+            {
+                var colName = textColumns[i];
+                var jsonIndicators = 0;
+                var totalNonNull = 0;
+
+                foreach (var row in sampleData)
+                {
+                    if (i < row.Length && row[i] != DBNull.Value)
+                    {
+                        totalNonNull++;
+                        var value = row[i]?.ToString()?.Trim() ?? string.Empty;
+
+                        // Check for JSON indicators
+                        if ((value.StartsWith("{") && value.EndsWith("}")) ||
+                            (value.StartsWith("[") && value.EndsWith("]")))
+                        {
+                            jsonIndicators++;
+                        }
+                    }
+                }
+
+                // If more than 50% of non-null values look like JSON, mark as JSON column
+                if (totalNonNull > 0 && (double)jsonIndicators / totalNonNull > 0.5)
+                {
+                    jsonColumns.Add(colName);
+                    _logger.LogInformation("Identified {ColumnName} as likely JSON column ({JsonIndicators}/{TotalNonNull} samples)", colName, jsonIndicators, totalNonNull);
+                }
+            }
+
+            return jsonColumns;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error identifying JSON columns in {TableName}: {Message}", tableName, ex.Message);
+            return [];
+        }
+    }
+
+    private List<object[]> ConvertGuidsToUppercase(List<object[]> data, TableInfo tableInfo)
+    {
+        if (data.Count == 0 || tableInfo.ColumnTypes.Count == 0)
+            return data;
+
+        // Identify GUID columns (uniqueidentifier type in SQL Server)
+        var guidColumnIndices = new List<int>();
+        for (int i = 0; i < tableInfo.Columns.Count; i++)
+        {
+            var columnName = tableInfo.Columns[i];
+            if (tableInfo.ColumnTypes.TryGetValue(columnName, out var columnType))
+            {
+                if (columnType.ToUpper().Contains("UNIQUEIDENTIFIER"))
+                {
+                    guidColumnIndices.Add(i);
+                    _logger.LogDebug("Found GUID column '{ColumnName}' at index {Index}", columnName, i);
+                }
+            }
+        }
+
+        if (guidColumnIndices.Count == 0)
+        {
+            _logger.LogDebug("No GUID columns found, returning data unchanged");
+            return data;
+        }
+
+        _logger.LogInformation("Converting {Count} GUID column(s) to uppercase", guidColumnIndices.Count);
+
+        // Process each row and convert GUID values to uppercase
+        var processedData = new List<object[]>();
+        foreach (var row in data)
+        {
+            var rowList = row.ToList();
+            foreach (var guidIdx in guidColumnIndices)
+            {
+                if (guidIdx < rowList.Count && rowList[guidIdx] != null && rowList[guidIdx] != DBNull.Value)
+                {
+                    var guidValue = rowList[guidIdx].ToString();
+                    // Convert to uppercase, preserving the GUID format
+                    rowList[guidIdx] = guidValue?.ToUpper() ?? 
string.Empty; + } + } + processedData.Add(rowList.ToArray()); + } + + return processedData; + } + + public bool TestConnection() + { + try + { + if (Connect()) + { + using var command = new SqlCommand("SELECT 1", _connection); + var result = command.ExecuteScalar(); + Disconnect(); + return result != null && (int)result == 1; + } + return false; + } + catch (Exception ex) + { + _logger.LogError("Connection test failed: {Message}", ex.Message); + return false; + } + } + + public void Dispose() + { + Disconnect(); + } +} diff --git a/util/Seeder/Migration/Databases/SqlServerImporter.cs b/util/Seeder/Migration/Databases/SqlServerImporter.cs new file mode 100644 index 0000000000..b48c835778 --- /dev/null +++ b/util/Seeder/Migration/Databases/SqlServerImporter.cs @@ -0,0 +1,671 @@ +using Microsoft.Data.SqlClient; +using Bit.Seeder.Migration.Models; +using Bit.Seeder.Migration.Utils; +using Microsoft.Extensions.Logging; +using System.Data; + +namespace Bit.Seeder.Migration.Databases; + +public class SqlServerImporter(DatabaseConfig config, ILogger logger) : IDisposable +{ + private readonly ILogger _logger = logger; + private readonly string _host = config.Host; + private readonly int _port = config.Port; + private readonly string _database = config.Database; + private readonly string _username = config.Username; + private readonly string _password = config.Password; + private SqlConnection? _connection; + private List<(string Schema, string Table, string Constraint)> _disabledConstraints = []; + + public bool Connect() + { + try + { + var connectionString = $"Server={_host},{_port};Database={_database};" + + $"User Id={_username};Password={_password};" + + $"TrustServerCertificate=True;Connection Timeout=30;"; + + _connection = new SqlConnection(connectionString); + _connection.Open(); + + _logger.LogInformation("Connected to SQL Server: {Host}/{Database}", _host, _database); + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to connect to SQL Server: {Message}", ex.Message); + return false; + } + } + + public void Disconnect() + { + if (_connection != null) + { + _connection.Close(); + _connection.Dispose(); + _connection = null; + _logger.LogInformation("Disconnected from SQL Server"); + } + } + + public List GetTableColumns(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = @" + SELECT COLUMN_NAME + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = @TableName + ORDER BY ORDINAL_POSITION"; + + using var command = new SqlCommand(query, _connection); + command.Parameters.AddWithValue("@TableName", tableName); + + var columns = new List(); + using var reader = command.ExecuteReader(); + while (reader.Read()) + { + columns.Add(reader.GetString(0)); + } + + return columns; + } + catch (Exception ex) + { + _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message); + return []; + } + } + + public bool TableExists(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = @" + SELECT COUNT(*) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_NAME = @TableName AND TABLE_TYPE = 'BASE TABLE'"; + + using var command = new SqlCommand(query, _connection); + command.Parameters.AddWithValue("@TableName", tableName); + + var count = (int)command.ExecuteScalar()!; + return count > 0; + } + catch (Exception ex) + { + _logger.LogError("Error checking if table {TableName} exists: 
{Message}", tableName, ex.Message); + return false; + } + } + + public int GetTableRowCount(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = $"SELECT COUNT(*) FROM [{tableName}]"; + using var command = new SqlCommand(query, _connection); + + var count = (int)command.ExecuteScalar()!; + _logger.LogDebug("Row count for {TableName}: {Count}", tableName, count); + return count; + } + catch (Exception ex) + { + _logger.LogError("Error getting row count for {TableName}: {Message}", tableName, ex.Message); + return 0; + } + } + + public bool DropTable(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = $"DROP TABLE IF EXISTS [{tableName}]"; + using var command = new SqlCommand(query, _connection); + command.ExecuteNonQuery(); + + _logger.LogInformation("Dropped table {TableName}", tableName); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error dropping table {TableName}: {Message}", tableName, ex.Message); + return false; + } + } + + public bool DisableForeignKeys() + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + _logger.LogInformation("Disabling foreign key constraints for SQL Server"); + + // Get all foreign key constraints + var query = @" + SELECT + OBJECT_SCHEMA_NAME(parent_object_id) AS schema_name, + OBJECT_NAME(parent_object_id) AS table_name, + name AS constraint_name + FROM sys.foreign_keys + WHERE is_disabled = 0"; + + using var command = new SqlCommand(query, _connection); + using var reader = command.ExecuteReader(); + + var constraints = new List<(string Schema, string Table, string Constraint)>(); + while (reader.Read()) + { + constraints.Add(( + reader.GetString(0), + reader.GetString(1), + reader.GetString(2) + )); + } + reader.Close(); + + // Disable each constraint + _disabledConstraints = []; + foreach (var (schema, table, constraint) in constraints) + { + try + { + var disableSql = $"ALTER TABLE [{schema}].[{table}] NOCHECK CONSTRAINT [{constraint}]"; + using var disableCommand = new SqlCommand(disableSql, _connection); + disableCommand.ExecuteNonQuery(); + + _disabledConstraints.Add((schema, table, constraint)); + _logger.LogDebug("Disabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table); + } + catch (Exception ex) + { + _logger.LogWarning("Could not disable constraint {Constraint}: {Message}", constraint, ex.Message); + } + } + + _logger.LogInformation("Disabled {Count} foreign key constraints", _disabledConstraints.Count); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error disabling foreign key constraints: {Message}", ex.Message); + return false; + } + } + + public bool EnableForeignKeys() + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + _logger.LogInformation("Re-enabling foreign key constraints for SQL Server"); + + var enabledCount = 0; + foreach (var (schema, table, constraint) in _disabledConstraints) + { + try + { + var enableSql = $"ALTER TABLE [{schema}].[{table}] CHECK CONSTRAINT [{constraint}]"; + using var command = new SqlCommand(enableSql, _connection); + command.ExecuteNonQuery(); + + enabledCount++; + _logger.LogDebug("Re-enabled constraint: {Constraint} on {Schema}.{Table}", constraint, schema, table); + } + catch (Exception ex) + { + _logger.LogWarning("Could not re-enable 
constraint {Constraint}: {Message}", constraint, ex.Message);
+                }
+            }
+
+            _logger.LogInformation("Re-enabled {EnabledCount}/{TotalCount} foreign key constraints", enabledCount, _disabledConstraints.Count);
+            _disabledConstraints.Clear();
+
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error re-enabling foreign key constraints: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public bool CreateTableFromSchema(
+        string tableName,
+        List<string> columns,
+        Dictionary<string, string> columnTypes,
+        List<string>? specialColumns = null)
+    {
+        specialColumns ??= [];
+
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            // Build column definitions
+            var sqlServerColumns = new List<string>();
+            foreach (var colName in columns)
+            {
+                var colType = columnTypes.GetValueOrDefault(colName, "NVARCHAR(MAX)");
+
+                // If it's a special JSON column, ensure it's a large text type
+                if (specialColumns.Contains(colName) &&
+                    !colType.ToUpper().Contains("VARCHAR(MAX)") &&
+                    !colType.ToUpper().Contains("TEXT"))
+                {
+                    colType = "NVARCHAR(MAX)";
+                }
+
+                sqlServerColumns.Add($"[{colName}] {colType}");
+            }
+
+            // Build CREATE TABLE statement
+            var createSql = $@"
+                CREATE TABLE [{tableName}] (
+                    {string.Join(",\n    ", sqlServerColumns)}
+                )";
+
+            _logger.LogInformation("Creating table {TableName} in SQL Server", tableName);
+            _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);
+
+            using var command = new SqlCommand(createSql, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogInformation("Successfully created table {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
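+
+    // Background: SQL Server rejects explicit values for IDENTITY columns unless
+    // IDENTITY_INSERT is ON for that one table, e.g. (hypothetical table name):
+    //
+    //     SET IDENTITY_INSERT [SomeTable] ON;
+    //     INSERT INTO [SomeTable] ([Id], [Name]) VALUES (42, N'x');
+    //     SET IDENTITY_INSERT [SomeTable] OFF;
+    //
+    // ImportData wraps imports with the helpers below in exactly this pattern.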
+
+    public List<string> GetIdentityColumns(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = @"
+                SELECT COLUMN_NAME
+                FROM INFORMATION_SCHEMA.COLUMNS
+                WHERE TABLE_NAME = @TableName
+                AND COLUMNPROPERTY(OBJECT_ID(TABLE_SCHEMA + '.' + TABLE_NAME), COLUMN_NAME, 'IsIdentity') = 1";
+
+            using var command = new SqlCommand(query, _connection);
+            command.Parameters.AddWithValue("@TableName", tableName);
+
+            var columns = new List<string>();
+            using var reader = command.ExecuteReader();
+            while (reader.Read())
+            {
+                columns.Add(reader.GetString(0));
+            }
+
+            return columns;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error getting identity columns for table {TableName}: {Message}", tableName, ex.Message);
+            return [];
+        }
+    }
+
+    public bool EnableIdentityInsert(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = $"SET IDENTITY_INSERT [{tableName}] ON";
+            using var command = new SqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogDebug("Enabled IDENTITY_INSERT for {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error enabling IDENTITY_INSERT for {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public bool DisableIdentityInsert(string tableName)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        try
+        {
+            var query = $"SET IDENTITY_INSERT [{tableName}] OFF";
+            using var command = new SqlCommand(query, _connection);
+            command.ExecuteNonQuery();
+
+            _logger.LogDebug("Disabled IDENTITY_INSERT for {TableName}", tableName);
+            return true;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Error disabling IDENTITY_INSERT for {TableName}: {Message}", tableName, ex.Message);
+            return false;
+        }
+    }
+
+    public bool ImportData(
+        string tableName,
+        List<string> columns,
+        List<object[]> data,
+        int batchSize = 1000)
+    {
+        if (_connection == null)
+            throw new InvalidOperationException("Not connected to database");
+
+        if (data.Count == 0)
+        {
+            _logger.LogWarning("No data to import for table {TableName}", tableName);
+            return true;
+        }
+
+        try
+        {
+            // Get actual table columns from SQL Server
+            var actualColumns = GetTableColumns(tableName);
+            if (actualColumns.Count == 0)
+            {
+                _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
+                return false;
+            }
+
+            // Filter columns and data
+            var validColumnIndices = new List<int>();
+            var validColumns = new List<string>();
+            var missingColumns = new List<string>();
+
+            for (int i = 0; i < columns.Count; i++)
+            {
+                if (actualColumns.Contains(columns[i]))
+                {
+                    validColumnIndices.Add(i);
+                    validColumns.Add(columns[i]);
+                }
+                else
+                {
+                    missingColumns.Add(columns[i]);
+                }
+            }
+
+            if (missingColumns.Count > 0)
+            {
+                _logger.LogWarning("Skipping columns that don't exist in {TableName}: {Columns}", tableName, string.Join(", ", missingColumns));
+            }
+
+            if (validColumns.Count == 0)
+            {
+                _logger.LogError("No valid columns found for table {TableName}", tableName);
+                return false;
+            }
+
+            // Filter data to only include valid columns
+            var filteredData = data.Select(row =>
+                validColumnIndices.Select(i => i < row.Length ? 
row[i] : null).ToArray() + ).ToList(); + + _logger.LogInformation("Valid columns for {TableName}: {Columns}", tableName, string.Join(", ", validColumns)); + + // Check if table has identity columns + var identityColumns = GetIdentityColumns(tableName); + var identityColumnsInData = validColumns.Intersect(identityColumns).ToList(); + var needsIdentityInsert = identityColumnsInData.Count > 0; + + if (needsIdentityInsert) + { + _logger.LogInformation("Table {TableName} has identity columns in import data: {Columns}", tableName, string.Join(", ", identityColumnsInData)); + _logger.LogInformation("Enabling IDENTITY_INSERT to allow explicit identity values"); + + if (!EnableIdentityInsert(tableName)) + { + _logger.LogError("Could not enable IDENTITY_INSERT for {TableName}", tableName); + return false; + } + } + + // Check for existing data + var existingCount = GetTableRowCount(tableName); + if (existingCount > 0) + { + _logger.LogWarning("Table {TableName} already contains {ExistingCount} rows - potential for primary key conflicts", tableName, existingCount); + } + + // Import using batch insert + var totalImported = FastBatchImport(tableName, validColumns, filteredData, batchSize); + + _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName); + + // Disable IDENTITY_INSERT if it was enabled + if (needsIdentityInsert) + { + if (!DisableIdentityInsert(tableName)) + { + _logger.LogWarning("Could not disable IDENTITY_INSERT for {TableName}", tableName); + } + } + + // Validate that data was actually inserted + var actualCount = GetTableRowCount(tableName); + _logger.LogInformation("Post-import validation for {TableName}: imported {TotalImported}, table contains {ActualCount}", tableName, totalImported, actualCount); + + if (actualCount < totalImported) + { + _logger.LogError("Import validation failed for {TableName}: expected at least {Expected}, found {Actual}", tableName, totalImported, actualCount); + return false; + } + + return true; + } + catch (Exception ex) + { + _logger.LogError("Error importing data into {TableName}: {Message}", tableName, ex.Message); + return false; + } + } + + private int UseSqlBulkCopy(string tableName, List columns, List data) + { + try + { + using var bulkCopy = new SqlBulkCopy(_connection!) 
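+            // Note: SqlBulkCopy streams rows via the TDS bulk-load path, which is
+            // much faster than per-row INSERTs; on failure this helper falls back
+            // to FastBatchImport. As written, ImportData calls FastBatchImport
+            // directly, so this is an available fast path rather than the default.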
+ { + DestinationTableName = $"[{tableName}]", + BatchSize = 10000, + BulkCopyTimeout = 600 // 10 minutes + }; + + // Map columns + foreach (var column in columns) + { + bulkCopy.ColumnMappings.Add(column, column); + } + + // Create DataTable + var dataTable = new DataTable(); + foreach (var column in columns) + { + dataTable.Columns.Add(column, typeof(object)); + } + + // Add rows with data type conversion + foreach (var row in data) + { + var preparedRow = PrepareRowForInsert(row, columns); + dataTable.Rows.Add(preparedRow); + } + + _logger.LogInformation("Using SqlBulkCopy for {Count} rows", data.Count); + bulkCopy.WriteToServer(dataTable); + + return data.Count; + } + catch (Exception ex) + { + _logger.LogWarning("SqlBulkCopy failed: {Message}, falling back to batch insert", ex.Message); + return FastBatchImport(tableName, columns, data, 1000); + } + } + + private int FastBatchImport(string tableName, List columns, List data, int batchSize) + { + var quotedColumns = columns.Select(col => $"[{col}]").ToList(); + var placeholders = string.Join(", ", columns.Select((_, i) => $"@p{i}")); + var insertSql = $"INSERT INTO [{tableName}] ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})"; + + var totalImported = 0; + + for (int i = 0; i < data.Count; i += batchSize) + { + var batch = data.Skip(i).Take(batchSize).ToList(); + + using var transaction = _connection!.BeginTransaction(); + try + { + foreach (var row in batch) + { + using var command = new SqlCommand(insertSql, _connection, transaction); + + var preparedRow = PrepareRowForInsert(row, columns); + for (int p = 0; p < preparedRow.Length; p++) + { + command.Parameters.AddWithValue($"@p{p}", preparedRow[p] ?? DBNull.Value); + } + + command.ExecuteNonQuery(); + } + + transaction.Commit(); + totalImported += batch.Count; + + if (data.Count > 1000) + { + _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{DataCount} total, {Percentage:F1}%)", batch.Count, totalImported, data.Count, (totalImported / (double)data.Count * 100)); + } + } + catch + { + transaction.Rollback(); + throw; + } + } + + return totalImported; + } + + private object[] PrepareRowForInsert(object?[] row, List columns) + { + var preparedRow = new object[row.Length]; + + for (int i = 0; i < row.Length; i++) + { + preparedRow[i] = ConvertValueForSqlServer(row[i]); + } + + return preparedRow; + } + + private object ConvertValueForSqlServer(object? value) + { + if (value == null || value == DBNull.Value) + return DBNull.Value; + + // Handle string conversions + if (value is string strValue) + { + // Only convert truly empty strings to DBNull, not whitespace + // This preserves JSON strings and other data that might have whitespace + if (strValue.Length == 0) + return DBNull.Value; + + // Handle boolean-like values + if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase)) + return 1; + if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase)) + return 0; + + // Handle datetime values - SQL Server DATETIME supports 3 decimal places + if (DateTimeHelper.IsLikelyIsoDateTime(strValue)) + { + try + { + // Remove timezone if present + var datetimePart = strValue.Contains('+') || strValue.EndsWith('Z') || strValue.Contains('T') + ? DateTimeHelper.RemoveTimezone(strValue) ?? 
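+                            // e.g. "2025-01-02T03:04:05.1234567Z" -> "2025-01-02T03:04:05.1234567",
+                            // assuming RemoveTimezone strips the zone designator; the fractional
+                            // seconds are then truncated to three digits just below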
strValue + : strValue; + + // Handle microseconds - SQL Server DATETIME precision is 3.33ms, so truncate to 3 digits + if (datetimePart.Contains('.')) + { + var parts = datetimePart.Split('.'); + if (parts.Length == 2 && parts[1].Length > 3) + { + datetimePart = $"{parts[0]}.{parts[1][..3]}"; + } + } + + return datetimePart; + } + catch + { + // If conversion fails, return original value + } + } + } + + return value; + } + + public bool TestConnection() + { + try + { + if (Connect()) + { + using var command = new SqlCommand("SELECT 1", _connection); + var result = command.ExecuteScalar(); + Disconnect(); + return result != null && (int)result == 1; + } + return false; + } + catch (Exception ex) + { + _logger.LogError("SQL Server import connection test failed: {Message}", ex.Message); + return false; + } + } + + public void Dispose() + { + Disconnect(); + } +} diff --git a/util/Seeder/Migration/Databases/SqliteImporter.cs b/util/Seeder/Migration/Databases/SqliteImporter.cs new file mode 100644 index 0000000000..2b81d40372 --- /dev/null +++ b/util/Seeder/Migration/Databases/SqliteImporter.cs @@ -0,0 +1,442 @@ +using Microsoft.Data.Sqlite; +using Bit.Seeder.Migration.Models; +using Bit.Seeder.Migration.Utils; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration.Databases; + +public class SqliteImporter(DatabaseConfig config, ILogger logger) : IDisposable +{ + private readonly ILogger _logger = logger; + private readonly string _databasePath = config.Database; + private SqliteConnection? _connection; + + public bool Connect() + { + try + { + // Ensure directory exists + var directory = Path.GetDirectoryName(_databasePath); + if (!string.IsNullOrEmpty(directory)) + { + Directory.CreateDirectory(directory); + } + + var connectionString = $"Data Source={_databasePath}"; + _connection = new SqliteConnection(connectionString); + _connection.Open(); + + // Enable foreign keys and set pragmas for better performance + using (var command = new SqliteCommand("PRAGMA foreign_keys = ON", _connection)) + { + command.ExecuteNonQuery(); + } + using (var command = new SqliteCommand("PRAGMA journal_mode = WAL", _connection)) + { + command.ExecuteNonQuery(); + } + using (var command = new SqliteCommand("PRAGMA synchronous = NORMAL", _connection)) + { + command.ExecuteNonQuery(); + } + + _logger.LogInformation("Connected to SQLite database: {DatabasePath}", _databasePath); + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to connect to SQLite: {Message}", ex.Message); + return false; + } + } + + public void Disconnect() + { + if (_connection != null) + { + try + { + // Force completion of any pending WAL operations + using (var command = new SqliteCommand("PRAGMA wal_checkpoint(TRUNCATE)", _connection)) + { + command.ExecuteNonQuery(); + } + } + catch (Exception ex) + { + _logger.LogWarning("Error during WAL checkpoint: {Message}", ex.Message); + } + + _connection.Close(); + _connection.Dispose(); + _connection = null; + _logger.LogInformation("Disconnected from SQLite"); + } + } + + public bool CreateTableFromSchema( + string tableName, + List columns, + Dictionary columnTypes, + List? 
specialColumns = null) + { + specialColumns ??= []; + + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var sqliteColumns = new List(); + foreach (var colName in columns) + { + var sqlServerType = columnTypes.GetValueOrDefault(colName, "VARCHAR(MAX)"); + var sqliteType = ConvertSqlServerTypeToSQLite(sqlServerType, specialColumns.Contains(colName)); + sqliteColumns.Add($"\"{colName}\" {sqliteType}"); + } + + var createSql = $@" + CREATE TABLE IF NOT EXISTS ""{tableName}"" ( + {string.Join(",\n ", sqliteColumns)} + )"; + + _logger.LogInformation("Creating table {TableName} in SQLite", tableName); + _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql); + + using var command = new SqliteCommand(createSql, _connection); + command.ExecuteNonQuery(); + + _logger.LogInformation("Successfully created table {TableName}", tableName); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message); + return false; + } + } + + public List GetTableColumns(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = $"PRAGMA table_info(\"{tableName}\")"; + using var command = new SqliteCommand(query, _connection); + using var reader = command.ExecuteReader(); + + var columns = new List(); + while (reader.Read()) + { + columns.Add(reader.GetString(1)); // Column name is at index 1 + } + + return columns; + } + catch (Exception ex) + { + _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message); + return []; + } + } + + public bool ImportData( + string tableName, + List columns, + List data, + int batchSize = 1000) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + if (data.Count == 0) + { + _logger.LogWarning("No data to import for table {TableName}", tableName); + return true; + } + + try + { + var actualColumns = GetTableColumns(tableName); + if (actualColumns.Count == 0) + { + _logger.LogError("Could not retrieve columns for table {TableName}", tableName); + return false; + } + + // Filter columns + var validColumnIndices = new List(); + var validColumns = new List(); + + for (int i = 0; i < columns.Count; i++) + { + if (actualColumns.Contains(columns[i])) + { + validColumnIndices.Add(i); + validColumns.Add(columns[i]); + } + } + + if (validColumns.Count == 0) + { + _logger.LogError("No valid columns found for table {TableName}", tableName); + return false; + } + + var filteredData = data.Select(row => + validColumnIndices.Select(i => i < row.Length ? 
row[i] : null).ToArray() + ).ToList(); + + _logger.LogInformation("Importing {Count} rows into {TableName}", filteredData.Count, tableName); + + // Build INSERT statement + var quotedColumns = validColumns.Select(col => $"\"{col}\"").ToList(); + var placeholders = string.Join(", ", Enumerable.Range(0, validColumns.Count).Select(i => $"@p{i}")); + var insertSql = $"INSERT INTO \"{tableName}\" ({string.Join(", ", quotedColumns)}) VALUES ({placeholders})"; + + // Begin transaction for all batches + using var transaction = _connection.BeginTransaction(); + try + { + var totalImported = 0; + for (int i = 0; i < filteredData.Count; i += batchSize) + { + var batch = filteredData.Skip(i).Take(batchSize).ToList(); + + foreach (var row in batch) + { + using var command = new SqliteCommand(insertSql, _connection, transaction); + + var preparedRow = PrepareRowForInsert(row, validColumns); + for (int p = 0; p < preparedRow.Length; p++) + { + var value = preparedRow[p] ?? DBNull.Value; + + // For string values, explicitly set parameter type to avoid truncation + if (value is string strValue) + { + var param = command.Parameters.Add($"@p{p}", Microsoft.Data.Sqlite.SqliteType.Text); + param.Value = strValue; + } + else + { + command.Parameters.AddWithValue($"@p{p}", value); + } + } + + command.ExecuteNonQuery(); + } + + totalImported += batch.Count; + + if (filteredData.Count > 1000) + { + _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count); + } + } + + transaction.Commit(); + + _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName); + return true; + } + catch + { + transaction.Rollback(); + throw; + } + } + catch (Exception ex) + { + _logger.LogError("Error importing data into {TableName}: {Message}", tableName, ex.Message); + return false; + } + } + + public bool TableExists(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name = @tableName"; + using var command = new SqliteCommand(query, _connection); + command.Parameters.AddWithValue("@tableName", tableName); + + var count = Convert.ToInt64(command.ExecuteScalar()); + return count > 0; + } + catch (Exception ex) + { + _logger.LogError("Error checking if table {TableName} exists: {Message}", tableName, ex.Message); + return false; + } + } + + public int GetTableRowCount(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = $"SELECT COUNT(*) FROM \"{tableName}\""; + using var command = new SqliteCommand(query, _connection); + + return Convert.ToInt32(command.ExecuteScalar()); + } + catch (Exception ex) + { + _logger.LogError("Error getting row count for {TableName}: {Message}", tableName, ex.Message); + return 0; + } + } + + public bool DropTable(string tableName) + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + var query = $"DROP TABLE IF EXISTS \"{tableName}\""; + using var command = new SqliteCommand(query, _connection); + command.ExecuteNonQuery(); + + _logger.LogInformation("Dropped table {TableName}", tableName); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error dropping table {TableName}: {Message}", tableName, ex.Message); + return false; + } + } + + public bool DisableForeignKeys() + 
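+    // Note: PRAGMA foreign_keys is per-connection in SQLite and is a no-op while
+    // a transaction is open, so these toggles must run outside BeginTransaction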
{ + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + _logger.LogInformation("Disabling foreign key constraints"); + var query = "PRAGMA foreign_keys = OFF"; + using var command = new SqliteCommand(query, _connection); + command.ExecuteNonQuery(); + + _logger.LogInformation("Foreign key constraints disabled"); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error disabling foreign key constraints: {Message}", ex.Message); + return false; + } + } + + public bool EnableForeignKeys() + { + if (_connection == null) + throw new InvalidOperationException("Not connected to database"); + + try + { + _logger.LogInformation("Re-enabling foreign key constraints"); + var query = "PRAGMA foreign_keys = ON"; + using var command = new SqliteCommand(query, _connection); + command.ExecuteNonQuery(); + + _logger.LogInformation("Foreign key constraints re-enabled"); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error re-enabling foreign key constraints: {Message}", ex.Message); + return false; + } + } + + private string ConvertSqlServerTypeToSQLite(string sqlServerType, bool isJsonColumn) + { + var baseType = sqlServerType.Replace(" NULL", "").Replace(" NOT NULL", "").Trim().ToUpper(); + var isNullable = !sqlServerType.Contains("NOT NULL"); + + if (isJsonColumn) + return "TEXT" + (isNullable ? "" : " NOT NULL"); + + // SQLite has only 5 storage classes: NULL, INTEGER, REAL, TEXT, BLOB + string sqliteType; + if (baseType.Contains("INT") || baseType.Contains("BIT")) + sqliteType = "INTEGER"; + else if (baseType.Contains("DECIMAL") || baseType.Contains("NUMERIC") || + baseType.Contains("FLOAT") || baseType.Contains("REAL") || baseType.Contains("MONEY")) + sqliteType = "REAL"; + else if (baseType.Contains("BINARY") || baseType == "IMAGE") + sqliteType = "BLOB"; + else + sqliteType = "TEXT"; + + return sqliteType + (isNullable ? "" : " NOT NULL"); + } + + private object[] PrepareRowForInsert(object?[] row, List columns) + { + return row.Select(value => + { + if (value == null || value == DBNull.Value) + return DBNull.Value; + + if (value is string strValue) + { + // Only convert truly empty strings to DBNull, not whitespace + // This preserves JSON strings and other data that might have whitespace + if (strValue.Length == 0) + return DBNull.Value; + + // Boolean to integer for SQLite + if (strValue.Equals("true", StringComparison.OrdinalIgnoreCase)) + return 1; + if (strValue.Equals("false", StringComparison.OrdinalIgnoreCase)) + return 0; + + // Handle datetime with timezone + if ((strValue.Contains('+') || strValue.EndsWith('Z')) && + DateTimeHelper.IsLikelyIsoDateTime(strValue)) + { + return DateTimeHelper.RemoveTimezone(strValue) ?? 
strValue;
+                }
+            }
+
+            return value;
+        }).ToArray();
+    }
+
+    public bool TestConnection()
+    {
+        try
+        {
+            if (Connect())
+            {
+                using var command = new SqliteCommand("SELECT 1", _connection);
+                var result = command.ExecuteScalar();
+                Disconnect();
+                return result != null && Convert.ToInt32(result) == 1;
+            }
+            return false;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("SQLite connection test failed: {Message}", ex.Message);
+            return false;
+        }
+    }
+
+    public void Dispose()
+    {
+        Disconnect();
+    }
+}
diff --git a/util/Seeder/Migration/Models/Config.cs b/util/Seeder/Migration/Models/Config.cs
new file mode 100644
index 0000000000..f545f8f61a
--- /dev/null
+++ b/util/Seeder/Migration/Models/Config.cs
@@ -0,0 +1,52 @@
+namespace Bit.Seeder.Migration.Models;
+
+public class DatabaseConfig
+{
+    public string Host { get; set; } = string.Empty;
+    public int Port { get; set; }
+    public string Database { get; set; } = string.Empty;
+    public string Username { get; set; } = string.Empty;
+    public string Password { get; set; } = string.Empty;
+    public string? Driver { get; set; }
+}
+
+public class CsvSettings
+{
+    public string OutputDir { get; set; } = "./exports";
+    public string Delimiter { get; set; } = ",";
+    public string Quoting { get; set; } = "QUOTE_ALL";
+    public string Encoding { get; set; } = "utf-8";
+    public bool IncludeHeaders { get; set; } = true;
+    public string FallbackDelimiter { get; set; } = "|";
+}
+
+public class SSHTunnelConfig
+{
+    public bool Enabled { get; set; } = false;
+    public string RemoteHost { get; set; } = string.Empty;
+    public string RemoteUser { get; set; } = string.Empty;
+    public int LocalPort { get; set; } = 1433;
+    public int RemotePort { get; set; } = 1433;
+    public string PrivateKeyPath { get; set; } = "~/.ssh/id_ed25519";
+    public string? PrivateKeyPassphrase { get; set; }
+}
+
+public class MigrationConfig
+{
+    public DatabaseConfig? Source { get; set; }
+    // Keyed by destination name; maps to that destination's connection settings
+    public Dictionary<string, DatabaseConfig> Destinations { get; set; } = new();
+    // Source table name -> destination table name overrides
+    public Dictionary<string, string> TableMappings { get; set; } = new();
+    // Table name -> columns treated as JSON (see the importers' specialColumns handling)
+    public Dictionary<string, List<string>> SpecialColumns { get; set; } = new();
+    public List<string> ExcludeTables { get; set; } = new();
+    public SSHTunnelConfig SshTunnel { get; set; } = new();
+    public CsvSettings CsvSettings { get; set; } = new();
+    public int BatchSize { get; set; } = 1000;
+}
+
+public class TableInfo
+{
+    public string Name { get; set; } = string.Empty;
+    public List<string> Columns { get; set; } = new();
+    public Dictionary<string, string> ColumnTypes { get; set; } = new();
+    public int RowCount { get; set; }
+}
diff --git a/util/Seeder/Migration/Models/ReporterModels.cs b/util/Seeder/Migration/Models/ReporterModels.cs
new file mode 100644
index 0000000000..1fd0ed007a
--- /dev/null
+++ b/util/Seeder/Migration/Models/ReporterModels.cs
@@ -0,0 +1,123 @@
+namespace Bit.Seeder.Migration.Models;
+
+public enum ImportStatus
+{
+    Success,
+    Failed,
+    Skipped,
+    Partial
+}
+
+public enum VerificationStatus
+{
+    Verified,
+    Mismatch,
+    Missing,
+    Error
+}
+
+public class TableImportStats
+{
+    public string TableName { get; set; } = string.Empty;
+    public string DestinationTable { get; set; } = string.Empty;
+    public ImportStatus Status { get; set; }
+    public int RowsLoaded { get; set; }
+    public int ExpectedRows { get; set; }
+    public DateTime StartTime { get; set; }
+    public DateTime EndTime { get; set; }
+    public string? ErrorMessage { get; set; }
+    public string? 
Notes { get; set; } + + public TimeSpan Duration => EndTime - StartTime; + + public double RowsPerSecond + { + get + { + var seconds = Duration.TotalSeconds; + return seconds > 0 ? RowsLoaded / seconds : 0; + } + } +} + +public class TableVerificationStats +{ + public string TableName { get; set; } = string.Empty; + public string DestinationTable { get; set; } = string.Empty; + public VerificationStatus Status { get; set; } + public int CsvRowCount { get; set; } + public int DatabaseRowCount { get; set; } + public string? ErrorMessage { get; set; } + + public int RowDifference => DatabaseRowCount - CsvRowCount; +} + +public class ImportSummaryStats +{ + public int TotalTables { get; set; } + public int SuccessfulTables { get; set; } + public int FailedTables { get; set; } + public int SkippedTables { get; set; } + public int TotalRowsImported { get; set; } + public DateTime StartTime { get; set; } + public DateTime EndTime { get; set; } + + public TimeSpan TotalDuration => EndTime - StartTime; + public int ErrorCount => FailedTables; + public double SuccessRate => TotalTables > 0 ? (double)SuccessfulTables / TotalTables * 100 : 0; +} + +public class VerificationSummaryStats +{ + public int TotalTables { get; set; } + public int VerifiedTables { get; set; } + public int MismatchedTables { get; set; } + public int MissingTables { get; set; } + public int ErrorTables { get; set; } + + public double SuccessRate => TotalTables > 0 ? (double)VerifiedTables / TotalTables * 100 : 0; +} + +public enum ExportStatus +{ + Success, + Failed, + Skipped +} + +public class TableExportStats +{ + public string TableName { get; set; } = string.Empty; + public ExportStatus Status { get; set; } + public int RowsExported { get; set; } + public DateTime StartTime { get; set; } + public DateTime EndTime { get; set; } + public string? ErrorMessage { get; set; } + public string? Notes { get; set; } + + public TimeSpan Duration => EndTime - StartTime; + + public double RowsPerSecond + { + get + { + var seconds = Duration.TotalSeconds; + return seconds > 0 ? RowsExported / seconds : 0; + } + } +} + +public class ExportSummaryStats +{ + public int TotalTables { get; set; } + public int SuccessfulTables { get; set; } + public int FailedTables { get; set; } + public int SkippedTables { get; set; } + public int TotalRowsExported { get; set; } + public DateTime StartTime { get; set; } + public DateTime EndTime { get; set; } + + public TimeSpan TotalDuration => EndTime - StartTime; + public int ErrorCount => FailedTables; + public double SuccessRate => TotalTables > 0 ? (double)SuccessfulTables / TotalTables * 100 : 0; +} diff --git a/util/Seeder/Migration/Reporters/ExportReporter.cs b/util/Seeder/Migration/Reporters/ExportReporter.cs new file mode 100644 index 0000000000..3d68b15b48 --- /dev/null +++ b/util/Seeder/Migration/Reporters/ExportReporter.cs @@ -0,0 +1,302 @@ +using Bit.Seeder.Migration.Models; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration.Reporters; + +public class ExportReporter(ILogger logger) +{ + private readonly ILogger _logger = logger; + private readonly List _tableStats = []; + private DateTime _exportStartTime; + private DateTime _exportEndTime; + private TableExportStats? 
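+    // Typical call sequence (inferred from usage): StartExport() once, then
+    // StartTable()/FinishTable() per table, then FinishExport() to print the
+    // summary; ExportReport(path) writes a plain-text copy of the same report.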
_currentTable; + + // ANSI color codes for console output + private const string ColorGreen = "\x1b[32m"; + private const string ColorRed = "\x1b[31m"; + private const string ColorYellow = "\x1b[33m"; + private const string ColorBlue = "\x1b[34m"; + private const string ColorCyan = "\x1b[36m"; + private const string ColorBold = "\x1b[1m"; + private const string ColorReset = "\x1b[0m"; + + // Separator constants for logging + private const string Separator = "================================================================================"; + private const string ShortSeparator = "----------------------------------------"; + + public void StartExport() + { + _exportStartTime = DateTime.Now; + _tableStats.Clear(); + Console.WriteLine(Separator); + Console.WriteLine($"{ColorBold}Starting Database Export{ColorReset}"); + Console.WriteLine(Separator); + } + + public void StartTable(string tableName) + { + _currentTable = new TableExportStats + { + TableName = tableName, + StartTime = DateTime.Now, + Status = ExportStatus.Failed // Default to failed, will update on success + }; + + Console.WriteLine($"\n{ColorBlue}[TABLE]{ColorReset} {ColorBold}{tableName}{ColorReset}"); + } + + public void FinishTable(ExportStatus status, int rowsExported, string? errorMessage = null, string? notes = null) + { + if (_currentTable == null) + return; + + _currentTable.EndTime = DateTime.Now; + _currentTable.Status = status; + _currentTable.RowsExported = rowsExported; + _currentTable.ErrorMessage = errorMessage; + _currentTable.Notes = notes; + + _tableStats.Add(_currentTable); + + // Log completion status + var statusColor = status switch + { + ExportStatus.Success => ColorGreen, + ExportStatus.Failed => ColorRed, + ExportStatus.Skipped => ColorYellow, + _ => ColorReset + }; + + var statusSymbol = status switch + { + ExportStatus.Success => "✓", + ExportStatus.Failed => "✗", + ExportStatus.Skipped => "⊘", + _ => "?" 
+ }; + + Console.WriteLine($"{statusColor}{statusSymbol} Status:{ColorReset} {status}"); + Console.WriteLine($"Rows exported: {rowsExported:N0}"); + Console.WriteLine($"Duration: {_currentTable.Duration.TotalSeconds:F2}s"); + Console.WriteLine($"Rate: {_currentTable.RowsPerSecond:F0} rows/sec"); + + if (!string.IsNullOrEmpty(errorMessage)) + { + Console.WriteLine($"{ColorRed}Error: {errorMessage}{ColorReset}"); + } + + if (!string.IsNullOrEmpty(notes)) + { + Console.WriteLine($"Notes: {notes}"); + } + + _currentTable = null; + } + + public void FinishExport() + { + _exportEndTime = DateTime.Now; + PrintDetailedReport(); + } + + public ExportSummaryStats GetSummaryStats() + { + return new ExportSummaryStats + { + TotalTables = _tableStats.Count, + SuccessfulTables = _tableStats.Count(t => t.Status == ExportStatus.Success), + FailedTables = _tableStats.Count(t => t.Status == ExportStatus.Failed), + SkippedTables = _tableStats.Count(t => t.Status == ExportStatus.Skipped), + TotalRowsExported = _tableStats.Sum(t => t.RowsExported), + StartTime = _exportStartTime, + EndTime = _exportEndTime + }; + } + + public List GetTableStats() => _tableStats.ToList(); + + public void PrintDetailedReport() + { + var summary = GetSummaryStats(); + + Console.WriteLine($"\n{Separator}"); + Console.WriteLine($"{ColorBold}Export Summary Report{ColorReset}"); + Console.WriteLine(Separator); + + // Overall statistics + Console.WriteLine($"\n{ColorBold}Overall Statistics:{ColorReset}"); + Console.WriteLine($" Total tables: {summary.TotalTables}"); + Console.WriteLine($" {ColorGreen}✓ Successful:{ColorReset} {summary.SuccessfulTables}"); + + if (summary.FailedTables > 0) + Console.WriteLine($" {ColorRed}✗ Failed:{ColorReset} {summary.FailedTables}"); + + if (summary.SkippedTables > 0) + Console.WriteLine($" {ColorYellow}⊘ Skipped:{ColorReset} {summary.SkippedTables}"); + + Console.WriteLine($" Total rows exported: {summary.TotalRowsExported:N0}"); + Console.WriteLine($" Total duration: {summary.TotalDuration.TotalMinutes:F2} minutes"); + Console.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + + // Per-table details + if (_tableStats.Count > 0) + { + // Calculate dynamic column widths based on actual data + var maxTableNameLength = _tableStats.Max(t => t.TableName.Length); + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + // Calculate max rows text length (format: "1,234,567") + var maxRowsTextLength = _tableStats.Max(t => $"{t.RowsExported:N0}".Length); + var rowsColumnWidth = Math.Max(15, maxRowsTextLength + 2); // Minimum 15, add 2 for padding + + // Calculate total width for dynamic separator + // tableColumnWidth + space + 10 (status) + space + rowsColumnWidth + space + 12 (duration) + space + 10 (rate) + var totalWidth = tableColumnWidth + 1 + 10 + 1 + rowsColumnWidth + 1 + 12 + 1 + 10; + var dynamicSeparator = new string('=', totalWidth); + + Console.WriteLine($"\n{ColorBold}Per-Table Details:{ColorReset}"); + Console.WriteLine(dynamicSeparator); + Console.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status".PadRight(10)} {"Rows".PadRight(rowsColumnWidth)} {"Duration".PadRight(12)} {"Rate",10}"); + Console.WriteLine(dynamicSeparator); + + foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var statusColor = stats.Status switch + { + ExportStatus.Success => ColorGreen, + ExportStatus.Failed => ColorRed, + ExportStatus.Skipped => ColorYellow, + _ => ColorReset + }; + + var statusText = 
$"{statusColor}{stats.Status.ToString().PadRight(10)}{ColorReset}"; + var rowsText = $"{stats.RowsExported:N0}"; + var durationText = $"{stats.Duration.TotalSeconds:F1}s"; + var rateText = $"{stats.RowsPerSecond:F0}/s"; + + Console.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {statusText} {rowsText.PadRight(rowsColumnWidth)} {durationText.PadRight(12)} {rateText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + Console.WriteLine($" {ColorRed}→ {stats.ErrorMessage}{ColorReset}"); + } + + if (!string.IsNullOrEmpty(stats.Notes)) + { + Console.WriteLine($" {ColorCyan}→ {stats.Notes}{ColorReset}"); + } + } + + Console.WriteLine(dynamicSeparator); + } + + // Failed tables summary + var failedTables = _tableStats.Where(t => t.Status == ExportStatus.Failed).ToList(); + if (failedTables.Count > 0) + { + Console.WriteLine($"\n{ColorRed}{ColorBold}Failed Tables:{ColorReset}"); + foreach (var failed in failedTables) + { + Console.WriteLine($" • {failed.TableName}: {failed.ErrorMessage}"); + } + } + + // Performance insights + if (_tableStats.Count > 0) + { + var successfulStats = _tableStats.Where(t => t.Status == ExportStatus.Success).ToList(); + var slowest = _tableStats.OrderByDescending(t => t.Duration).First(); + var fastest = _tableStats.Where(t => t.RowsExported > 0) + .OrderByDescending(t => t.RowsPerSecond) + .FirstOrDefault(); + + Console.WriteLine($"\n{ColorBold}Performance Insights:{ColorReset}"); + + if (successfulStats.Count > 0) + { + var avgRate = successfulStats.Average(t => t.RowsPerSecond); + Console.WriteLine($" Average export rate: {avgRate:F0} rows/sec"); + } + + Console.WriteLine($" Slowest table: {slowest.TableName} ({slowest.Duration.TotalSeconds:F1}s)"); + if (fastest != null) + { + Console.WriteLine($" Fastest table: {fastest.TableName} ({fastest.RowsPerSecond:F0} rows/sec)"); + } + } + + Console.WriteLine($"\n{Separator}"); + + // Final status + if (summary.FailedTables == 0) + { + Console.WriteLine($"{ColorGreen}{ColorBold}✓ Export completed successfully!{ColorReset}"); + } + else + { + Console.WriteLine($"{ColorRed}{ColorBold}✗ Export completed with {summary.FailedTables} failed table(s){ColorReset}"); + } + + Console.WriteLine($"{Separator}\n"); + } + + public void ExportReport(string filePath) + { + try + { + using var writer = new StreamWriter(filePath); + var summary = GetSummaryStats(); + + writer.WriteLine("Database Export Report"); + writer.WriteLine($"Generated: {DateTime.Now}"); + writer.WriteLine(new string('=', 80)); + writer.WriteLine(); + + writer.WriteLine("Overall Statistics:"); + writer.WriteLine($" Total tables: {summary.TotalTables}"); + writer.WriteLine($" Successful: {summary.SuccessfulTables}"); + writer.WriteLine($" Failed: {summary.FailedTables}"); + writer.WriteLine($" Skipped: {summary.SkippedTables}"); + writer.WriteLine($" Total rows exported: {summary.TotalRowsExported:N0}"); + writer.WriteLine($" Total duration: {summary.TotalDuration.TotalMinutes:F2} minutes"); + writer.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + writer.WriteLine(); + + // Calculate dynamic column width based on longest table name + var maxTableNameLength = _tableStats.Max(t => t.TableName.Length); + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + writer.WriteLine("Per-Table Details:"); + writer.WriteLine(new string('-', 80)); + writer.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status",-10} {"Rows",-15} {"Duration",-12} {"Rate",10}"); + writer.WriteLine(new string('-', 80)); + + 
foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var rowsText = $"{stats.RowsExported:N0}"; + var durationText = $"{stats.Duration.TotalSeconds:F1}s"; + var rateText = $"{stats.RowsPerSecond:F0}/s"; + + writer.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {stats.Status,-10} {rowsText,-15} {durationText,-12} {rateText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + writer.WriteLine($" Error: {stats.ErrorMessage}"); + } + + if (!string.IsNullOrEmpty(stats.Notes)) + { + writer.WriteLine($" Notes: {stats.Notes}"); + } + } + + writer.WriteLine(new string('-', 80)); + + _logger.LogInformation("Export report exported to: {FilePath}", filePath); + } + catch (Exception ex) + { + _logger.LogError("Failed to export report: {Message}", ex.Message); + } + } +} diff --git a/util/Seeder/Migration/Reporters/ImportReporter.cs b/util/Seeder/Migration/Reporters/ImportReporter.cs new file mode 100644 index 0000000000..f5f190f336 --- /dev/null +++ b/util/Seeder/Migration/Reporters/ImportReporter.cs @@ -0,0 +1,308 @@ +using Bit.Seeder.Migration.Models; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration.Reporters; + +public class ImportReporter(ILogger logger) +{ + private readonly ILogger _logger = logger; + private readonly List _tableStats = []; + private DateTime _importStartTime; + private DateTime _importEndTime; + private TableImportStats? _currentTable; + + // ANSI color codes for console output + private const string ColorGreen = "\x1b[32m"; + private const string ColorRed = "\x1b[31m"; + private const string ColorYellow = "\x1b[33m"; + private const string ColorBlue = "\x1b[34m"; + private const string ColorCyan = "\x1b[36m"; + private const string ColorBold = "\x1b[1m"; + private const string ColorReset = "\x1b[0m"; + + // Separator constants for logging + private const string Separator = "================================================================================"; + private const string ShortSeparator = "----------------------------------------"; + + public void StartImport() + { + _importStartTime = DateTime.Now; + _tableStats.Clear(); + Console.WriteLine(Separator); + Console.WriteLine($"{ColorBold}Starting Database Import{ColorReset}"); + Console.WriteLine(Separator); + } + + public void StartTable(string tableName, string destinationTable, int expectedRows) + { + _currentTable = new TableImportStats + { + TableName = tableName, + DestinationTable = destinationTable, + ExpectedRows = expectedRows, + StartTime = DateTime.Now, + Status = ImportStatus.Failed // Default to failed, will update on success + }; + + Console.WriteLine($"\n{ColorBlue}[TABLE]{ColorReset} {ColorBold}{tableName}{ColorReset} -> {destinationTable}"); + Console.WriteLine($"Expected rows: {expectedRows:N0}"); + } + + public void FinishTable(ImportStatus status, int rowsLoaded, string? errorMessage = null, string? 
notes = null) + { + if (_currentTable == null) + return; + + _currentTable.EndTime = DateTime.Now; + _currentTable.Status = status; + _currentTable.RowsLoaded = rowsLoaded; + _currentTable.ErrorMessage = errorMessage; + _currentTable.Notes = notes; + + _tableStats.Add(_currentTable); + + // Log completion status + var statusColor = status switch + { + ImportStatus.Success => ColorGreen, + ImportStatus.Failed => ColorRed, + ImportStatus.Partial => ColorYellow, + ImportStatus.Skipped => ColorYellow, + _ => ColorReset + }; + + var statusSymbol = status switch + { + ImportStatus.Success => "✓", + ImportStatus.Failed => "✗", + ImportStatus.Partial => "⚠", + ImportStatus.Skipped => "⊘", + _ => "?" + }; + + Console.WriteLine($"{statusColor}{statusSymbol} Status:{ColorReset} {status}"); + Console.WriteLine($"Rows loaded: {rowsLoaded:N0} / {_currentTable.ExpectedRows:N0}"); + Console.WriteLine($"Duration: {_currentTable.Duration.TotalSeconds:F2}s"); + Console.WriteLine($"Rate: {_currentTable.RowsPerSecond:F0} rows/sec"); + + if (!string.IsNullOrEmpty(errorMessage)) + { + Console.WriteLine($"{ColorRed}Error: {errorMessage}{ColorReset}"); + } + + if (!string.IsNullOrEmpty(notes)) + { + Console.WriteLine($"Notes: {notes}"); + } + + _currentTable = null; + } + + public void FinishImport() + { + _importEndTime = DateTime.Now; + PrintDetailedReport(); + } + + public ImportSummaryStats GetSummaryStats() + { + return new ImportSummaryStats + { + TotalTables = _tableStats.Count, + SuccessfulTables = _tableStats.Count(t => t.Status == ImportStatus.Success), + FailedTables = _tableStats.Count(t => t.Status == ImportStatus.Failed), + SkippedTables = _tableStats.Count(t => t.Status == ImportStatus.Skipped), + TotalRowsImported = _tableStats.Sum(t => t.RowsLoaded), + StartTime = _importStartTime, + EndTime = _importEndTime + }; + } + + public List GetTableStats() => _tableStats.ToList(); + + public void PrintDetailedReport() + { + var summary = GetSummaryStats(); + + Console.WriteLine($"\n{Separator}"); + Console.WriteLine($"{ColorBold}Import Summary Report{ColorReset}"); + Console.WriteLine(Separator); + + // Overall statistics + Console.WriteLine($"\n{ColorBold}Overall Statistics:{ColorReset}"); + Console.WriteLine($" Total tables: {summary.TotalTables}"); + Console.WriteLine($" {ColorGreen}✓ Successful:{ColorReset} {summary.SuccessfulTables}"); + + if (summary.FailedTables > 0) + Console.WriteLine($" {ColorRed}✗ Failed:{ColorReset} {summary.FailedTables}"); + + if (summary.SkippedTables > 0) + Console.WriteLine($" {ColorYellow}⊘ Skipped:{ColorReset} {summary.SkippedTables}"); + + Console.WriteLine($" Total rows imported: {summary.TotalRowsImported:N0}"); + Console.WriteLine($" Total duration: {summary.TotalDuration.TotalMinutes:F2} minutes"); + Console.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + + // Per-table details + if (_tableStats.Count > 0) + { + // Calculate dynamic column widths based on actual data + var maxTableNameLength = _tableStats.Max(t => t.TableName.Length); + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + // Calculate max rows text length (format: "1,234/5,678") + var maxRowsTextLength = _tableStats.Max(t => $"{t.RowsLoaded:N0}/{t.ExpectedRows:N0}".Length); + var rowsColumnWidth = Math.Max(15, maxRowsTextLength + 2); // Minimum 15, add 2 for padding + + // Calculate total width for dynamic separator + // tableColumnWidth + space + 10 (status) + space + rowsColumnWidth + space + 12 (duration) + space + 10 (rate) + var totalWidth 
= tableColumnWidth + 1 + 10 + 1 + rowsColumnWidth + 1 + 12 + 1 + 10; + var dynamicSeparator = new string('=', totalWidth); + + Console.WriteLine($"\n{ColorBold}Per-Table Details:{ColorReset}"); + Console.WriteLine(dynamicSeparator); + Console.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status".PadRight(10)} {"Rows".PadRight(rowsColumnWidth)} {"Duration".PadRight(12)} {"Rate",10}"); + Console.WriteLine(dynamicSeparator); + + foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var statusColor = stats.Status switch + { + ImportStatus.Success => ColorGreen, + ImportStatus.Failed => ColorRed, + ImportStatus.Partial => ColorYellow, + ImportStatus.Skipped => ColorYellow, + _ => ColorReset + }; + + var statusText = $"{statusColor}{stats.Status.ToString().PadRight(10)}{ColorReset}"; + var rowsText = $"{stats.RowsLoaded:N0}/{stats.ExpectedRows:N0}"; + var durationText = $"{stats.Duration.TotalSeconds:F1}s"; + var rateText = $"{stats.RowsPerSecond:F0}/s"; + + Console.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {statusText} {rowsText.PadRight(rowsColumnWidth)} {durationText.PadRight(12)} {rateText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + Console.WriteLine($" {ColorRed}→ {stats.ErrorMessage}{ColorReset}"); + } + + if (!string.IsNullOrEmpty(stats.Notes)) + { + Console.WriteLine($" {ColorCyan}→ {stats.Notes}{ColorReset}"); + } + } + + Console.WriteLine(dynamicSeparator); + } + + // Failed tables summary + var failedTables = _tableStats.Where(t => t.Status == ImportStatus.Failed).ToList(); + if (failedTables.Count > 0) + { + Console.WriteLine($"\n{ColorRed}{ColorBold}Failed Tables:{ColorReset}"); + foreach (var failed in failedTables) + { + Console.WriteLine($" • {failed.TableName}: {failed.ErrorMessage}"); + } + } + + // Performance insights + if (_tableStats.Count > 0) + { + var successfulStats = _tableStats.Where(t => t.Status == ImportStatus.Success).ToList(); + var slowest = _tableStats.OrderByDescending(t => t.Duration).First(); + var fastest = _tableStats.Where(t => t.RowsLoaded > 0) + .OrderByDescending(t => t.RowsPerSecond) + .FirstOrDefault(); + + Console.WriteLine($"\n{ColorBold}Performance Insights:{ColorReset}"); + + if (successfulStats.Count > 0) + { + var avgRate = successfulStats.Average(t => t.RowsPerSecond); + Console.WriteLine($" Average import rate: {avgRate:F0} rows/sec"); + } + + Console.WriteLine($" Slowest table: {slowest.TableName} ({slowest.Duration.TotalSeconds:F1}s)"); + if (fastest != null) + { + Console.WriteLine($" Fastest table: {fastest.TableName} ({fastest.RowsPerSecond:F0} rows/sec)"); + } + } + + Console.WriteLine($"\n{Separator}"); + + // Final status + if (summary.FailedTables == 0) + { + Console.WriteLine($"{ColorGreen}{ColorBold}✓ Import completed successfully!{ColorReset}"); + } + else + { + Console.WriteLine($"{ColorRed}{ColorBold}✗ Import completed with {summary.FailedTables} failed table(s){ColorReset}"); + } + + Console.WriteLine($"{Separator}\n"); + } + + public void ExportReport(string filePath) + { + try + { + using var writer = new StreamWriter(filePath); + var summary = GetSummaryStats(); + + writer.WriteLine("Database Import Report"); + writer.WriteLine($"Generated: {DateTime.Now}"); + writer.WriteLine(new string('=', 80)); + writer.WriteLine(); + + writer.WriteLine("Overall Statistics:"); + writer.WriteLine($" Total tables: {summary.TotalTables}"); + writer.WriteLine($" Successful: {summary.SuccessfulTables}"); + writer.WriteLine($" Failed: {summary.FailedTables}"); + writer.WriteLine($" Skipped: 
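// Worked example of the width arithmetic above (illustrative values): with a longest
// table name of 38 chars, tableColumnWidth = Math.Max(30, 38 + 2) = 40; with a longest
// rows string "12,345/12,345" (13 chars), rowsColumnWidth = Math.Max(15, 13 + 2) = 15;
// so totalWidth = 40 + 1 + 10 + 1 + 15 + 1 + 12 + 1 + 10 = 91 '=' characters.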
{summary.SkippedTables}"); + writer.WriteLine($" Total rows imported: {summary.TotalRowsImported:N0}"); + writer.WriteLine($" Total duration: {summary.TotalDuration.TotalMinutes:F2} minutes"); + writer.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + writer.WriteLine(); + + // Calculate dynamic column width based on longest table name + var maxTableNameLength = _tableStats.Max(t => t.TableName.Length); + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + writer.WriteLine("Per-Table Details:"); + writer.WriteLine(new string('-', 80)); + writer.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status",-10} {"Rows",-15} {"Duration",-12} {"Rate",10}"); + writer.WriteLine(new string('-', 80)); + + foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var rowsText = $"{stats.RowsLoaded:N0}/{stats.ExpectedRows:N0}"; + var durationText = $"{stats.Duration.TotalSeconds:F1}s"; + var rateText = $"{stats.RowsPerSecond:F0}/s"; + + writer.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {stats.Status,-10} {rowsText,-15} {durationText,-12} {rateText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + writer.WriteLine($" Error: {stats.ErrorMessage}"); + } + + if (!string.IsNullOrEmpty(stats.Notes)) + { + writer.WriteLine($" Notes: {stats.Notes}"); + } + } + + writer.WriteLine(new string('-', 80)); + + _logger.LogInformation("Import report exported to: {FilePath}", filePath); + } + catch (Exception ex) + { + _logger.LogError("Failed to export report: {Message}", ex.Message); + } + } +} diff --git a/util/Seeder/Migration/Reporters/VerificationReporter.cs b/util/Seeder/Migration/Reporters/VerificationReporter.cs new file mode 100644 index 0000000000..33ff7b9260 --- /dev/null +++ b/util/Seeder/Migration/Reporters/VerificationReporter.cs @@ -0,0 +1,299 @@ +using Bit.Seeder.Migration.Models; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration.Reporters; + +public class VerificationReporter(ILogger logger) +{ + private readonly ILogger _logger = logger; + private readonly List _tableStats = []; + + // ANSI color codes for console output + private const string ColorGreen = "\x1b[32m"; + private const string ColorRed = "\x1b[31m"; + private const string ColorYellow = "\x1b[33m"; + private const string ColorBlue = "\x1b[34m"; + private const string ColorBold = "\x1b[1m"; + private const string ColorReset = "\x1b[0m"; + + // Separator constants for logging + private const string Separator = "================================================================================"; + private const string ShortSeparator = "----------------------------------------"; + + public void StartVerification() + { + _tableStats.Clear(); + Console.WriteLine(Separator); + Console.WriteLine($"{ColorBold}Starting Import Verification{ColorReset}"); + Console.WriteLine(Separator); + } + + public void VerifyTable( + string tableName, + string destinationTable, + int csvRowCount, + int databaseRowCount, + string? 
errorMessage = null) + { + var status = DetermineStatus(csvRowCount, databaseRowCount, errorMessage); + + var stats = new TableVerificationStats + { + TableName = tableName, + DestinationTable = destinationTable, + CsvRowCount = csvRowCount, + DatabaseRowCount = databaseRowCount, + Status = status, + ErrorMessage = errorMessage + }; + + _tableStats.Add(stats); + + // Log verification result + var statusColor = status switch + { + VerificationStatus.Verified => ColorGreen, + VerificationStatus.Mismatch => ColorRed, + VerificationStatus.Missing => ColorYellow, + VerificationStatus.Error => ColorRed, + _ => ColorReset + }; + + var statusSymbol = status switch + { + VerificationStatus.Verified => "✓", + VerificationStatus.Mismatch => "✗", + VerificationStatus.Missing => "?", + VerificationStatus.Error => "!", + _ => "?" + }; + + Console.WriteLine($"\n{ColorBlue}[TABLE]{ColorReset} {ColorBold}{tableName}{ColorReset} -> {destinationTable}"); + Console.WriteLine($"{statusColor}{statusSymbol} Status:{ColorReset} {status}"); + Console.WriteLine($"CSV rows: {csvRowCount:N0}"); + Console.WriteLine($"Database rows: {databaseRowCount:N0}"); + + if (stats.RowDifference != 0) + { + var diffColor = stats.RowDifference > 0 ? ColorGreen : ColorRed; + Console.WriteLine($"Difference: {diffColor}{stats.RowDifference:+#;-#;0}{ColorReset}"); + } + + if (!string.IsNullOrEmpty(errorMessage)) + { + Console.WriteLine($"{ColorRed}Error: {errorMessage}{ColorReset}"); + } + } + + public void FinishVerification() + { + PrintVerificationReport(); + } + + public VerificationSummaryStats GetSummaryStats() + { + return new VerificationSummaryStats + { + TotalTables = _tableStats.Count, + VerifiedTables = _tableStats.Count(t => t.Status == VerificationStatus.Verified), + MismatchedTables = _tableStats.Count(t => t.Status == VerificationStatus.Mismatch), + MissingTables = _tableStats.Count(t => t.Status == VerificationStatus.Missing), + ErrorTables = _tableStats.Count(t => t.Status == VerificationStatus.Error) + }; + } + + public List<TableVerificationStats> GetTableStats() => _tableStats.ToList(); + + public void PrintVerificationReport() + { + var summary = GetSummaryStats(); + + Console.WriteLine($"\n{Separator}"); + Console.WriteLine($"{ColorBold}Verification Summary Report{ColorReset}"); + Console.WriteLine(Separator); + + // Overall statistics + Console.WriteLine($"\n{ColorBold}Overall Statistics:{ColorReset}"); + Console.WriteLine($" Total tables: {summary.TotalTables}"); + Console.WriteLine($" {ColorGreen}✓ Verified:{ColorReset} {summary.VerifiedTables}"); + + if (summary.MismatchedTables > 0) + Console.WriteLine($" {ColorRed}✗ Mismatched:{ColorReset} {summary.MismatchedTables}"); + + if (summary.MissingTables > 0) + Console.WriteLine($" {ColorYellow}? Missing:{ColorReset} {summary.MissingTables}"); + + if (summary.ErrorTables > 0) + Console.WriteLine($" {ColorRed}!
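// Note on the "+#;-#;0" format used for RowDifference above: .NET custom numeric
// formats take up to three ';'-separated sections (positive;negative;zero), so a
// difference of +5 renders as "+5", -3 as "-3", and 0 as "0".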
Errors:{ColorReset} {summary.ErrorTables}"); + + Console.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + + // Per-table details + if (_tableStats.Count > 0) + { + // Calculate dynamic column widths based on actual data + var maxTableNameLength = _tableStats.Max(t => t.TableName.Length); + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + // Calculate max text lengths for numeric columns + var maxCsvTextLength = _tableStats.Max(t => $"{t.CsvRowCount:N0}".Length); + var maxDbTextLength = _tableStats.Max(t => $"{t.DatabaseRowCount:N0}".Length); + var csvColumnWidth = Math.Max(10, maxCsvTextLength + 2); // Minimum 10, add 2 for padding + var dbColumnWidth = Math.Max(10, maxDbTextLength + 2); // Minimum 10, add 2 for padding + + // Calculate total width for dynamic separator + // tableColumnWidth + space + 12 (status) + space + csvColumnWidth + space + dbColumnWidth + space + 10 (diff) + var totalWidth = tableColumnWidth + 1 + 12 + 1 + csvColumnWidth + 1 + dbColumnWidth + 1 + 10; + var dynamicSeparator = new string('=', totalWidth); + + Console.WriteLine($"\n{ColorBold}Per-Table Details:{ColorReset}"); + Console.WriteLine(dynamicSeparator); + Console.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status".PadRight(12)} {"CSV Rows".PadLeft(csvColumnWidth)} {"DB Rows".PadLeft(dbColumnWidth)} {"Diff",10}"); + Console.WriteLine(dynamicSeparator); + + foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var statusColor = stats.Status switch + { + VerificationStatus.Verified => ColorGreen, + VerificationStatus.Mismatch => ColorRed, + VerificationStatus.Missing => ColorYellow, + VerificationStatus.Error => ColorRed, + _ => ColorReset + }; + + var statusText = $"{statusColor}{stats.Status.ToString().PadRight(12)}{ColorReset}"; + var csvText = $"{stats.CsvRowCount:N0}"; + var dbText = $"{stats.DatabaseRowCount:N0}"; + var diffText = stats.RowDifference != 0 + ? $"{(stats.RowDifference > 0 ? 
ColorGreen : ColorRed)}{stats.RowDifference:+#;-#;0}{ColorReset}" + : "0"; + + Console.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {statusText} {csvText.PadLeft(csvColumnWidth)} {dbText.PadLeft(dbColumnWidth)} {diffText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + Console.WriteLine($" {ColorRed}→ {stats.ErrorMessage}{ColorReset}"); + } + } + + Console.WriteLine(dynamicSeparator); + } + + // Problem tables + var problemTables = _tableStats + .Where(t => t.Status != VerificationStatus.Verified) + .ToList(); + + if (problemTables.Count > 0) + { + Console.WriteLine($"\n{ColorRed}{ColorBold}Tables Needing Attention:{ColorReset}"); + + foreach (var problem in problemTables) + { + var issueType = problem.Status switch + { + VerificationStatus.Mismatch => "Row count mismatch", + VerificationStatus.Missing => "CSV file not found", + VerificationStatus.Error => "Verification error", + _ => "Unknown issue" + }; + + Console.WriteLine($" • {problem.TableName}: {issueType}"); + + if (problem.Status == VerificationStatus.Mismatch) + { + Console.WriteLine($" Expected: {problem.CsvRowCount:N0}, Found: {problem.DatabaseRowCount:N0}"); + } + + if (!string.IsNullOrEmpty(problem.ErrorMessage)) + { + Console.WriteLine($" Error: {problem.ErrorMessage}"); + } + } + } + + Console.WriteLine($"\n{Separator}"); + + // Final status + if (summary.MismatchedTables == 0 && summary.ErrorTables == 0 && summary.MissingTables == 0) + { + Console.WriteLine($"{ColorGreen}{ColorBold}✓ All tables verified successfully!{ColorReset}"); + } + else + { + var problemCount = summary.MismatchedTables + summary.ErrorTables + summary.MissingTables; + Console.WriteLine($"{ColorRed}{ColorBold}✗ Verification completed with {problemCount} issue(s){ColorReset}"); + } + + Console.WriteLine($"{Separator}\n"); + } + + public void ExportReport(string filePath) + { + try + { + using var writer = new StreamWriter(filePath); + var summary = GetSummaryStats(); + + writer.WriteLine("Database Verification Report"); + writer.WriteLine($"Generated: {DateTime.Now}"); + writer.WriteLine(new string('=', 80)); + writer.WriteLine(); + + writer.WriteLine("Overall Statistics:"); + writer.WriteLine($" Total tables: {summary.TotalTables}"); + writer.WriteLine($" Verified: {summary.VerifiedTables}"); + writer.WriteLine($" Mismatched: {summary.MismatchedTables}"); + writer.WriteLine($" Missing: {summary.MissingTables}"); + writer.WriteLine($" Errors: {summary.ErrorTables}"); + writer.WriteLine($" Success rate: {summary.SuccessRate:F1}%"); + writer.WriteLine(); + + // Calculate dynamic column width based on longest table name (0 when no tables were recorded) + var maxTableNameLength = _tableStats.Count > 0 ? _tableStats.Max(t => t.TableName.Length) : 0; + var tableColumnWidth = Math.Max(30, maxTableNameLength + 2); // Minimum 30, add 2 for padding + + writer.WriteLine("Per-Table Details:"); + writer.WriteLine(new string('-', 80)); + writer.WriteLine($"{"Table".PadRight(tableColumnWidth)} {"Status",-12} {"CSV Rows",12} {"DB Rows",12} {"Diff",10}"); + writer.WriteLine(new string('-', 80)); + + foreach (var stats in _tableStats.OrderBy(t => t.TableName)) + { + var csvText = $"{stats.CsvRowCount:N0}"; + var dbText = $"{stats.DatabaseRowCount:N0}"; + var diffText = stats.RowDifference != 0 ?
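// Usage sketch for VerificationReporter (illustrative only; `loggerFactory` and the
// row counts are assumptions):
//   var verifier = new VerificationReporter(loggerFactory.CreateLogger<VerificationReporter>());
//   verifier.StartVerification();
//   verifier.VerifyTable("User", "users", csvRowCount: 100, databaseRowCount: 100);  // -> Verified
//   verifier.VerifyTable("Cipher", "ciphers", csvRowCount: 250, databaseRowCount: 240); // -> Mismatch
//   verifier.FinishVerification(); // prints the summary report shown above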
$"{stats.RowDifference:+#;-#;0}" : "0"; + + writer.WriteLine($"{stats.TableName.PadRight(tableColumnWidth)} {stats.Status,-12} {csvText,12} {dbText,12} {diffText,10}"); + + if (!string.IsNullOrEmpty(stats.ErrorMessage)) + { + writer.WriteLine($" Error: {stats.ErrorMessage}"); + } + } + + writer.WriteLine(new string('-', 80)); + + _logger.LogInformation("Verification report exported to: {FilePath}", filePath); + } + catch (Exception ex) + { + _logger.LogError("Failed to export report: {Message}", ex.Message); + } + } + + private static VerificationStatus DetermineStatus(int csvRowCount, int databaseRowCount, string? errorMessage) + { + if (!string.IsNullOrEmpty(errorMessage)) + return VerificationStatus.Error; + + if (csvRowCount < 0) + return VerificationStatus.Missing; + + if (csvRowCount == databaseRowCount) + return VerificationStatus.Verified; + + return VerificationStatus.Mismatch; + } +} diff --git a/util/Seeder/Migration/SchemaMapper.cs b/util/Seeder/Migration/SchemaMapper.cs new file mode 100644 index 0000000000..0190a73b21 --- /dev/null +++ b/util/Seeder/Migration/SchemaMapper.cs @@ -0,0 +1,209 @@ +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration; + +public class SchemaMapper( + Dictionary tableMappings, + Dictionary> specialColumns, + ILogger logger) +{ + private readonly ILogger _logger = logger; + private readonly Dictionary _tableMappings = tableMappings ?? []; + private readonly Dictionary> _specialColumns = specialColumns ?? []; + private readonly Dictionary _reverseMappings = (tableMappings ?? []).ToDictionary(kv => kv.Value, kv => kv.Key); + + public string GetDestinationTableName(string sourceTable, string? destinationDbType = null) + { + // For SQL Server to SQL Server, don't apply table mappings (schema is identical) + if (destinationDbType == "sqlserver") + { + _logger.LogDebug("SQL Server destination: keeping original table name {SourceTable}", sourceTable); + return sourceTable; + } + + // For other databases, apply configured mappings + var mappedName = _tableMappings.GetValueOrDefault(sourceTable, sourceTable); + + if (mappedName != sourceTable) + { + _logger.LogDebug("Mapped table {SourceTable} -> {MappedName} for {DestinationDbType}", sourceTable, mappedName, destinationDbType); + } + + return mappedName; + } + + public string GetSourceTableName(string destinationTable) => + _reverseMappings.GetValueOrDefault(destinationTable, destinationTable); + + public List GetSpecialColumnsForTable(string tableName) + { + var specialCols = _specialColumns.GetValueOrDefault(tableName, []); + + if (specialCols.Count > 0) + { + _logger.LogDebug("Table {TableName} has special columns: {Columns}", tableName, string.Join(", ", specialCols)); + } + + return specialCols; + } + + public bool IsSpecialColumn(string tableName, string columnName) + { + var specialCols = GetSpecialColumnsForTable(tableName); + return specialCols.Contains(columnName); + } + + public Dictionary SuggestTableMappings(List sourceTables) + { + var suggestions = new Dictionary(); + + foreach (var table in sourceTables) + { + // Check if table is singular and suggest plural + if (!table.EndsWith("s") && !table.EndsWith("es")) + { + string suggested; + + if (table.EndsWith("y")) + { + // Company -> Companies + suggested = table[..^1] + "ies"; + } + else if (table.EndsWith("s") || table.EndsWith("sh") || table.EndsWith("ch") || + table.EndsWith("x") || table.EndsWith("z")) + { + // Class -> Classes, Box -> Boxes + suggested = table + "es"; + } + else if (table.EndsWith("f")) + { + // Shelf -> 
Shelves + suggested = table[..^1] + "ves"; + } + else if (table.EndsWith("fe")) + { + // Life -> Lives + suggested = table[..^2] + "ves"; + } + else + { + // User -> Users + suggested = table + "s"; + } + + suggestions[table] = suggested; + } + } + + if (suggestions.Count > 0) + { + _logger.LogInformation("Suggested table mappings (singular -> plural):"); + foreach (var (source, dest) in suggestions) + { + _logger.LogInformation(" {Source} -> {Dest}", source, dest); + } + } + + return suggestions; + } + + public bool ValidateMappings(List<string> sourceTables) + { + var sourceSet = new HashSet<string>(sourceTables); + var invalidMappings = new List<string>(); + + foreach (var sourceTable in _tableMappings.Keys) + { + if (!sourceSet.Contains(sourceTable)) + { + invalidMappings.Add(sourceTable); + } + } + + if (invalidMappings.Count > 0) + { + _logger.LogError("Invalid table mappings found: {InvalidMappings}", string.Join(", ", invalidMappings)); + _logger.LogError("Available source tables: {SourceTables}", string.Join(", ", sourceTables.OrderBy(t => t))); + return false; + } + + _logger.LogInformation("All table mappings are valid"); + return true; + } + + public void AddMapping(string sourceTable, string destinationTable) + { + _tableMappings[sourceTable] = destinationTable; + _reverseMappings[destinationTable] = sourceTable; + _logger.LogInformation("Added mapping: {SourceTable} -> {DestinationTable}", sourceTable, destinationTable); + } + + public void AddSpecialColumns(string tableName, List<string> columns) + { + if (!_specialColumns.ContainsKey(tableName)) + { + _specialColumns[tableName] = []; + } + + _specialColumns[tableName].AddRange(columns); + // Remove duplicates + _specialColumns[tableName] = _specialColumns[tableName].Distinct().ToList(); + + _logger.LogInformation("Added special columns to {TableName}: {Columns}", tableName, string.Join(", ", columns)); + } + + public Dictionary<string, List<string>> DetectNamingPatterns(List<string> tableNames) + { + var patterns = new Dictionary<string, List<string>> + { + ["singular"] = [], + ["plural"] = [], + ["mixed_case"] = [], + ["snake_case"] = [], + ["camel_case"] = [], + ["all_caps"] = [], + ["all_lower"] = [] + }; + + foreach (var table in tableNames) + { + // Case patterns + if (table.All(char.IsUpper)) + patterns["all_caps"].Add(table); + else if (table.All(char.IsLower)) + patterns["all_lower"].Add(table); + else if (table.Contains('_')) + patterns["snake_case"].Add(table); + else if (table.Skip(1).Any(char.IsUpper)) + patterns["camel_case"].Add(table); + else + patterns["mixed_case"].Add(table); + + // Singular/plural detection (simple heuristic) + if (table.EndsWith("s") || table.EndsWith("es") || table.EndsWith("ies")) + patterns["plural"].Add(table); + else + patterns["singular"].Add(table); + } + + // Log pattern analysis + _logger.LogInformation("Table naming pattern analysis:"); + foreach (var (pattern, tables) in patterns.Where(p => p.Value.Count > 0)) + { + var preview = string.Join(", ", tables.Take(3)); + var ellipsis = tables.Count > 3 ? "..."
: ""; + _logger.LogInformation(" {Pattern}: {Count} tables - {Preview}{Ellipsis}", pattern, tables.Count, preview, ellipsis); + } + + return patterns; + } + + public void LogInitialization() + { + _logger.LogInformation("Initialized schema mapper with {Count} table mappings", _tableMappings.Count); + foreach (var (source, dest) in _tableMappings) + { + _logger.LogInformation(" {Source} -> {Dest}", source, dest); + } + } +} diff --git a/util/Seeder/Migration/TableFilter.cs b/util/Seeder/Migration/TableFilter.cs new file mode 100644 index 0000000000..05ae0657f4 --- /dev/null +++ b/util/Seeder/Migration/TableFilter.cs @@ -0,0 +1,209 @@ +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Migration; + +public class TableFilter( + List? includeTables, + List? excludeTables, + List? permanentExclusions, + ILogger logger) +{ + private readonly ILogger _logger = logger; + private readonly HashSet _includeTables = includeTables?.ToHashSet() ?? []; + private readonly HashSet _adHocExcludeTables = excludeTables?.ToHashSet() ?? []; + private readonly HashSet _permanentExclusions = permanentExclusions?.ToHashSet() ?? []; + private readonly HashSet _excludeTables = InitializeExcludeTables(excludeTables, permanentExclusions, includeTables, logger); + + private static HashSet InitializeExcludeTables( + List? excludeTables, + List? permanentExclusions, + List? includeTables, + ILogger logger) + { + var adHocExcludeSet = excludeTables?.ToHashSet() ?? []; + var permanentExcludeSet = permanentExclusions?.ToHashSet() ?? []; + var includeSet = includeTables?.ToHashSet() ?? []; + + var result = new HashSet(adHocExcludeSet); + result.UnionWith(permanentExcludeSet); + + // Remove any permanently excluded tables from include list + if (includeSet.Count > 0 && permanentExcludeSet.Count > 0) + { + var conflictingIncludes = includeSet.Intersect(permanentExcludeSet).ToList(); + if (conflictingIncludes.Count > 0) + { + logger.LogWarning("Removing permanently excluded tables from include list: {Tables}", string.Join(", ", conflictingIncludes.OrderBy(t => t))); + includeSet.ExceptWith(conflictingIncludes); + } + } + + // Validate that both include and exclude aren't used together (for ad-hoc only) + if (includeSet.Count > 0 && adHocExcludeSet.Count > 0) + { + logger.LogWarning("Both include and ad-hoc exclude tables specified. 
Include takes precedence over ad-hoc exclusions."); + return new HashSet<string>(permanentExcludeSet); + } + + return result; + } + + public void LogFilterSetup() + { + if (_includeTables.Count > 0) + { + _logger.LogInformation("Table filter: INCLUDING only {Count} tables: {Tables}", _includeTables.Count, string.Join(", ", _includeTables.OrderBy(t => t))); + if (_permanentExclusions.Count > 0) + { + _logger.LogInformation("Plus permanently excluding {Count} tables: {Tables}", _permanentExclusions.Count, string.Join(", ", _permanentExclusions.OrderBy(t => t))); + } + } + else if (_excludeTables.Count > 0) + { + if (_permanentExclusions.Count > 0 && _adHocExcludeTables.Count > 0) + { + _logger.LogInformation("Table filter: EXCLUDING {Count} tables total:", _excludeTables.Count); + _logger.LogInformation(" - Permanent exclusions: {Tables}", string.Join(", ", _permanentExclusions.OrderBy(t => t))); + _logger.LogInformation(" - Ad-hoc exclusions: {Tables}", string.Join(", ", _adHocExcludeTables.OrderBy(t => t))); + } + else if (_permanentExclusions.Count > 0) + { + _logger.LogInformation("Table filter: EXCLUDING {Count} permanent tables: {Tables}", _permanentExclusions.Count, string.Join(", ", _permanentExclusions.OrderBy(t => t))); + } + else if (_adHocExcludeTables.Count > 0) + { + _logger.LogInformation("Table filter: EXCLUDING {Count} ad-hoc tables: {Tables}", _adHocExcludeTables.Count, string.Join(", ", _adHocExcludeTables.OrderBy(t => t))); + } + } + else + { + _logger.LogInformation("Table filter: No filtering applied (processing all tables)"); + } + } + + public bool ShouldProcessTable(string tableName) + { + // If include list is specified, only process tables in that list (permanent exclusions still win) + if (_includeTables.Count > 0) + { + if (_permanentExclusions.Contains(tableName)) + { + _logger.LogDebug("Skipping table {TableName} (permanently excluded)", tableName); + return false; + } + var result = _includeTables.Contains(tableName); + if (!result) + { + _logger.LogDebug("Skipping table {TableName} (not in include list)", tableName); + } + return result; + } + + // If exclude list is specified, process all tables except those in the list + if (_excludeTables.Count > 0) + { + var result = !_excludeTables.Contains(tableName); + if (!result) + { + _logger.LogDebug("Skipping table {TableName} (in exclude list)", tableName); + } + return result; + } + + // No filtering - process all tables + return true; + } + + public List<string> FilterTableList(List<string> allTables) + { + var originalCount = allTables.Count; + var filteredTables = allTables.Where(ShouldProcessTable).ToList(); + + _logger.LogInformation("Table filtering result: {FilteredCount}/{OriginalCount} tables selected for processing", filteredTables.Count, originalCount); + + if (_includeTables.Count > 0) + { + // Check if any requested include tables are missing + var availableSet = new HashSet<string>(allTables); + var missingTables = _includeTables.Except(availableSet).ToList(); + if (missingTables.Count > 0) + { + _logger.LogWarning("Requested tables not found: {Tables}", string.Join(", ", missingTables.OrderBy(t => t))); + _logger.LogInformation("Available tables: {Tables}", string.Join(", ", allTables.OrderBy(t => t))); + } + } + + return filteredTables; + } + + public string GetFilterDescription() + { + if (_includeTables.Count > 0) + { + var baseDesc = $"Including only: {string.Join(", ", _includeTables.OrderBy(t => t))}"; + if (_permanentExclusions.Count > 0) + { + baseDesc += $" (plus {_permanentExclusions.Count} permanent exclusions)"; + } + return baseDesc; + } + + if (_excludeTables.Count > 0) + { + if (_permanentExclusions.Count > 0 && _adHocExcludeTables.Count > 0) + { + return $"Excluding: {_permanentExclusions.Count}
permanent + {_adHocExcludeTables.Count} ad-hoc tables"; + } + if (_permanentExclusions.Count > 0) + { + return $"Excluding: {string.Join(", ", _permanentExclusions.OrderBy(t => t))} (permanent)"; + } + return $"Excluding: {string.Join(", ", _excludeTables.OrderBy(t => t))}"; + } + + if (_permanentExclusions.Count > 0) + { + return $"No additional filtering (permanent exclusions: {string.Join(", ", _permanentExclusions.OrderBy(t => t))})"; + } + + return "No table filtering applied"; + } + + public bool ValidateTablesExist(List<string> availableTables) + { + var availableSet = new HashSet<string>(availableTables); + var issuesFound = false; + + if (_includeTables.Count > 0) + { + var missingInclude = _includeTables.Except(availableSet).ToList(); + if (missingInclude.Count > 0) + { + _logger.LogError("Include tables not found: {Tables}", string.Join(", ", missingInclude.OrderBy(t => t))); + issuesFound = true; + } + } + + if (_excludeTables.Count > 0) + { + var missingExclude = _excludeTables.Except(availableSet).ToList(); + if (missingExclude.Count > 0) + { + _logger.LogWarning("Exclude tables not found (will be ignored): {Tables}", string.Join(", ", missingExclude.OrderBy(t => t))); + } + } + + return !issuesFound; + } + + public static List<string> ParseTableList(string? tableString) + { + if (string.IsNullOrWhiteSpace(tableString)) + return []; + + // Split by comma and trim whitespace + return tableString.Split(',') + .Select(t => t.Trim()) + .Where(t => !string.IsNullOrEmpty(t)) + .ToList(); + } + + public List<string>? GetIncludeTables() => _includeTables.Count > 0 ? _includeTables.ToList() : null; + + public List<string>? GetExcludeTables() => _adHocExcludeTables.Count > 0 ? _adHocExcludeTables.ToList() : null; +} diff --git a/util/Seeder/Migration/Utils/DateTimeHelper.cs b/util/Seeder/Migration/Utils/DateTimeHelper.cs new file mode 100644 index 0000000000..98e08c79f0 --- /dev/null +++ b/util/Seeder/Migration/Utils/DateTimeHelper.cs @@ -0,0 +1,56 @@ +namespace Bit.Seeder.Migration.Utils; + +public static class DateTimeHelper +{ + /// <summary> + /// Checks if a string looks like an ISO datetime format (YYYY-MM-DD or variations with time). + /// This prevents false positives with data containing '+' signs (like base64 encoded strings). + /// </summary> + public static bool IsLikelyIsoDateTime(string value) + { + if (string.IsNullOrEmpty(value)) + return false; + + // Must have reasonable length for a datetime string + if (value.Length < 10 || value.Length > 35) + return false; + + // Must start with a digit + if (!char.IsDigit(value[0])) + return false; + + // Must contain a dash (date separator) + if (!value.Contains('-')) + return false; + + // Dash should be at position 2 (MM-DD-...) or 4 (YYYY-MM-...) + var dashIndex = value.IndexOf('-'); + if (dashIndex != 2 && dashIndex != 4) + return false; + + // Must have either 'T' separator or both ':' and '-' for datetime + return value.Contains('T') || (value.Contains(':') && value.Contains('-')); + } + + /// <summary> + /// Extracts the datetime portion from an ISO datetime string, removing timezone info. + /// </summary> + public static string?
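// Usage sketch for TableFilter (illustrative; the table names are assumptions).
// Precedence: the include list is checked first, with permanent exclusions always
// winning; ParseTableList splits a comma-separated CLI value.
//   var include = TableFilter.ParseTableList("User, Cipher,Organization");
//   var filter = new TableFilter(include, null, ["Grant"], loggerFactory.CreateLogger<TableFilter>());
//   filter.ShouldProcessTable("User");  // true  (in include list)
//   filter.ShouldProcessTable("Event"); // false (not in include list)
//   filter.ShouldProcessTable("Grant"); // false (permanently excluded)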
RemoveTimezone(string value) + { + if (string.IsNullOrEmpty(value)) + return null; + + var result = value; + + // Remove a timezone suffix: '+HH:mm', a trailing 'Z', or '-HH:mm' (any '-' after the 10-char date portion) + var plusIndex = result.IndexOf('+'); + if (plusIndex >= 0) + { + result = result[..plusIndex]; + } + else if (result.EndsWith('Z')) + { + result = result[..^1]; + } + else if (result.Length > 10) + { + var minusIndex = result.IndexOf('-', 10); + if (minusIndex >= 0) + result = result[..minusIndex]; + } + + // Convert ISO 'T' separator to space for SQL compatibility + result = result.Replace('T', ' '); + + return result; + } +} diff --git a/util/Seeder/Migration/Utils/SecuritySanitizer.cs b/util/Seeder/Migration/Utils/SecuritySanitizer.cs new file mode 100644 index 0000000000..38ee6356f2 --- /dev/null +++ b/util/Seeder/Migration/Utils/SecuritySanitizer.cs @@ -0,0 +1,70 @@ +using System.Text.RegularExpressions; + +namespace Bit.Seeder.Migration.Utils; + +public static class SecuritySanitizer +{ + private static readonly string[] SensitiveFields = + [ + "password", "passwd", "pwd", "secret", "key", "token", + "api_key", "auth_token", "access_token", "private_key" + ]; + + public static string MaskPassword(string password) + { + if (string.IsNullOrEmpty(password)) + return string.Empty; + + if (password.Length <= 4) + return "***"; + + return password[..2] + new string('*', password.Length - 4) + password[^2..]; + } + + public static Dictionary<string, object?> SanitizeConfigForDisplay(Dictionary<string, object?> configDict) + { + var sanitized = new Dictionary<string, object?>(); + + foreach (var (key, value) in configDict) + { + if (SensitiveFields.Contains(key.ToLower())) + { + sanitized[key] = value != null ? MaskPassword(value.ToString() ?? string.Empty) : string.Empty; + } + else if (value is Dictionary<string, object?> nestedDict) + { + sanitized[key] = SanitizeConfigForDisplay(nestedDict); + } + else + { + sanitized[key] = value; + } + } + + return sanitized; + } + + public static string SanitizeLogMessage(string message) + { + var patterns = new Dictionary<string, string> + { + [@"password\s*[:=]\s*['""]?([^'""\s,}]+)['""]?"] = "password=***", + [@"passwd\s*[:=]\s*['""]?([^'""\s,}]+)['""]?"] = "passwd=***", + [@"""password""\s*:\s*""[^""]*"""] = @"""password"": ""***""", + [@"'password'\s*:\s*'[^']*'"] = @"'password': '***'" + }; + + var sanitized = message; + foreach (var (pattern, replacement) in patterns) + { + sanitized = Regex.Replace(sanitized, pattern, replacement, RegexOptions.IgnoreCase); + } + + return sanitized; + } + + public static string CreateSafeConnectionString(string host, int port, string database, string username) + { + return $"{username}@{host}:{port}/{database}"; + } +} diff --git a/util/Seeder/Migration/Utils/SshTunnel.cs b/util/Seeder/Migration/Utils/SshTunnel.cs new file mode 100644 index 0000000000..6cb1fff74d --- /dev/null +++ b/util/Seeder/Migration/Utils/SshTunnel.cs @@ -0,0 +1,271 @@ +using Microsoft.Extensions.Logging; +using Renci.SshNet; +using System.Net.Sockets; + +namespace Bit.Seeder.Migration.Utils; + +public class SshTunnel( + string remoteHost, + string remoteUser, + int localPort, + int remotePort, + string? privateKeyPath, + string? privateKeyPassphrase, + ILogger<SshTunnel> logger) : IDisposable +{ + private readonly ILogger<SshTunnel> _logger = logger; + private readonly string _remoteHost = remoteHost; + private readonly string _remoteUser = remoteUser; + private readonly int _localPort = localPort; + private readonly int _remotePort = remotePort; + private readonly string? _privateKeyPath = privateKeyPath; + private readonly string? _privateKeyPassphrase = privateKeyPassphrase; + private SshClient? _sshClient; + private ForwardedPortLocal?
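// Behavior sketch for the two helpers above (illustrative inputs):
//   DateTimeHelper.IsLikelyIsoDateTime("2024-03-15T10:30:00+02:00"); // true
//   DateTimeHelper.IsLikelyIsoDateTime("aGVsbG8rd29ybGQ=");          // false (base64, no date shape)
//   DateTimeHelper.RemoveTimezone("2024-03-15T10:30:00+02:00");      // "2024-03-15 10:30:00"
//   SecuritySanitizer.MaskPassword("hunter2secret");                 // "hu*********et"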
_forwardedPort; + private bool _isConnected; + + public bool StartTunnel() + { + if (_isConnected) + { + _logger.LogWarning("SSH tunnel is already connected"); + return true; + } + + _logger.LogInformation("Starting SSH tunnel: {RemoteUser}@{RemoteHost}", _remoteUser, _remoteHost); + _logger.LogInformation("Port forwarding: localhost:{LocalPort} -> {RemoteHost}:{RemotePort}", _localPort, _remoteHost, _remotePort); + + try + { + // Create SSH client with authentication + if (!string.IsNullOrEmpty(_privateKeyPath)) + { + var keyPath = ExpandPath(_privateKeyPath); + if (File.Exists(keyPath)) + { + _logger.LogDebug("Using SSH private key: {KeyPath}", keyPath); + + PrivateKeyFile keyFile; + if (!string.IsNullOrEmpty(_privateKeyPassphrase)) + { + _logger.LogDebug("Using passphrase for encrypted private key"); + keyFile = new PrivateKeyFile(keyPath, _privateKeyPassphrase); + } + else + { + // Try without passphrase first + try + { + keyFile = new PrivateKeyFile(keyPath); + } + catch (Exception ex) when (ex.Message.Contains("passphrase")) + { + _logger.LogInformation("SSH private key is encrypted. Please enter passphrase:"); + var passphrase = ReadPassword(); + if (string.IsNullOrEmpty(passphrase)) + { + throw new Exception("SSH private key requires a passphrase but none was provided"); + } + keyFile = new PrivateKeyFile(keyPath, passphrase); + } + } + + _sshClient = new SshClient(_remoteHost, _remoteUser, keyFile); + } + else + { + _logger.LogWarning("SSH private key not found: {KeyPath}, trying password authentication", keyPath); + _sshClient = new SshClient(_remoteHost, _remoteUser, string.Empty); + } + } + else + { + _logger.LogInformation("No SSH key specified, using keyboard-interactive authentication"); + _sshClient = new SshClient(_remoteHost, _remoteUser, string.Empty); + } + + // Configure SSH client + _sshClient.ConnectionInfo.Timeout = TimeSpan.FromSeconds(30); + _sshClient.KeepAliveInterval = TimeSpan.FromSeconds(30); + + // Connect SSH client + _logger.LogInformation("Connecting to SSH server..."); + _sshClient.Connect(); + + if (!_sshClient.IsConnected) + { + _logger.LogError("SSH connection failed"); + return false; + } + + _logger.LogInformation("SSH connection established"); + + // Create port forwarding + _forwardedPort = new ForwardedPortLocal("localhost", (uint)_localPort, "localhost", (uint)_remotePort); + _sshClient.AddForwardedPort(_forwardedPort); + + // Start port forwarding + _logger.LogInformation("Starting port forwarding..."); + _forwardedPort.Start(); + + // Wait a moment for tunnel to establish + Thread.Sleep(2000); + + // Test tunnel connectivity + if (TestTunnelConnectivity()) + { + _isConnected = true; + _logger.LogInformation("SSH tunnel established successfully"); + return true; + } + + _logger.LogError("SSH tunnel started but port is not accessible"); + StopTunnel(); + return false; + } + catch (Exception ex) + { + _logger.LogError("Error starting SSH tunnel: {Message}", ex.Message); + StopTunnel(); + return false; + } + } + + public void StopTunnel() + { + try + { + if (_forwardedPort != null) + { + _logger.LogInformation("Stopping SSH tunnel..."); + + if (_forwardedPort.IsStarted) + { + _forwardedPort.Stop(); + } + + _forwardedPort.Dispose(); + _forwardedPort = null; + } + + if (_sshClient != null) + { + if (_sshClient.IsConnected) + { + _sshClient.Disconnect(); + } + + _sshClient.Dispose(); + _sshClient = null; + } + + _isConnected = false; + _logger.LogInformation("SSH tunnel stopped"); + } + catch (Exception ex) + { + _logger.LogWarning("Error stopping 
SSH tunnel: {Message}", ex.Message); + } + } + + public bool IsTunnelActive() + { + if (!_isConnected || _sshClient == null || _forwardedPort == null) + return false; + + if (!_sshClient.IsConnected || !_forwardedPort.IsStarted) + { + _logger.LogWarning("SSH tunnel process has terminated"); + _isConnected = false; + return false; + } + + if (!TestTunnelConnectivity()) + { + _logger.LogWarning("SSH tunnel process running but port not accessible"); + return false; + } + + return true; + } + + private bool TestTunnelConnectivity() + { + try + { + using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReceiveTimeout, 5000); + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.SendTimeout, 5000); + + var result = socket.BeginConnect("localhost", _localPort, null, null); + var success = result.AsyncWaitHandle.WaitOne(5000, true); + + if (success) + { + socket.EndConnect(result); + _logger.LogDebug("Tunnel port {LocalPort} is accessible", _localPort); + return true; + } + + _logger.LogDebug("Tunnel port {LocalPort} connection timeout", _localPort); + return false; + } + catch (Exception ex) + { + _logger.LogDebug("Error testing tunnel connectivity: {Message}", ex.Message); + return false; + } + } + + public Dictionary<string, object> GetConnectionInfo() => new() + { + ["remote_host"] = _remoteHost, + ["remote_user"] = _remoteUser, + ["local_port"] = _localPort, + ["remote_port"] = _remotePort, + ["is_connected"] = _isConnected, + ["client_connected"] = _sshClient?.IsConnected ?? false, + ["port_forwarding_active"] = _forwardedPort?.IsStarted ?? false + }; + + private static string ExpandPath(string path) + { + if (path.StartsWith("~/")) + { + var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + return Path.Combine(home, path[2..]); + } + return path; + } + + private static string ReadPassword() + { + var password = string.Empty; + ConsoleKeyInfo key; + + do + { + key = Console.ReadKey(intercept: true); + + if (key.Key != ConsoleKey.Backspace && key.Key != ConsoleKey.Enter) + { + password += key.KeyChar; + Console.Write("*"); + } + else if (key.Key == ConsoleKey.Backspace && password.Length > 0) + { + password = password[0..^1]; + Console.Write("\b \b"); + } + } + while (key.Key != ConsoleKey.Enter); + + Console.WriteLine(); + return password; + } + + public void Dispose() + { + StopTunnel(); + } +} diff --git a/util/Seeder/Recipes/CsvMigrationRecipe.cs b/util/Seeder/Recipes/CsvMigrationRecipe.cs new file mode 100644 index 0000000000..7f30851ed7 --- /dev/null +++ b/util/Seeder/Recipes/CsvMigrationRecipe.cs @@ -0,0 +1,545 @@ +using Bit.Seeder.Migration; +using Bit.Seeder.Migration.Databases; +using Bit.Seeder.Migration.Models; +using Bit.Seeder.Migration.Reporters; +using Bit.Seeder.Migration.Utils; +using Microsoft.Extensions.Logging; + +namespace Bit.Seeder.Recipes; + +public class CsvMigrationRecipe(MigrationConfig config, ILoggerFactory loggerFactory) +{ + private readonly ILogger<CsvMigrationRecipe> _logger = loggerFactory.CreateLogger<CsvMigrationRecipe>(); + private readonly MigrationConfig _config = config; + private readonly ILoggerFactory _loggerFactory = loggerFactory; + private readonly SchemaMapper _schemaMapper = new( + config.TableMappings, + config.SpecialColumns, + loggerFactory.CreateLogger<SchemaMapper>()); + private readonly CsvHandler _csvHandler = new( + config.CsvSettings, + loggerFactory.CreateLogger<CsvHandler>()); + private SshTunnel? _sshTunnel; + private SqlServerExporter?
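// Usage sketch for SshTunnel (illustrative host, user, ports and key path; key-based
// auth is the dependable path here, since the empty-password fallback above will
// typically be rejected by the server):
//   using var tunnel = new SshTunnel(
//       "db.internal.example", "deploy", 15432, 5432,
//       "~/.ssh/id_ed25519", null, loggerFactory.CreateLogger<SshTunnel>());
//   if (tunnel.StartTunnel())
//   {
//       // point the database client at localhost:15432 while the tunnel is up
//   }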
_sourceExporter; + + // Separator constants for logging + private const string Separator = "================================================================================"; + private const string ShortSeparator = "----------------------------------------"; + + public bool StartSshTunnel(bool force = false) + { + if (!force && !_config.SshTunnel.Enabled) + { + _logger.LogInformation("SSH tunnel not enabled in configuration"); + return true; + } + + try + { + _logger.LogInformation("Starting SSH tunnel to {RemoteHost}...", _config.SshTunnel.RemoteHost); + _sshTunnel = new SshTunnel( + _config.SshTunnel.RemoteHost, + _config.SshTunnel.RemoteUser, + _config.SshTunnel.LocalPort, + _config.SshTunnel.RemotePort, + _config.SshTunnel.PrivateKeyPath, + _config.SshTunnel.PrivateKeyPassphrase, + _loggerFactory.CreateLogger<SshTunnel>()); + + return _sshTunnel.StartTunnel(); + } + catch (Exception ex) + { + _logger.LogError("Failed to start SSH tunnel: {Message}", ex.Message); + return false; + } + } + + public void StopSshTunnel() + { + if (_sshTunnel != null) + { + _sshTunnel.StopTunnel(); + _sshTunnel.Dispose(); + _sshTunnel = null; + } + } + + public bool DiscoverAndAnalyzeTables() + { + if (_config.Source == null) + { + _logger.LogError("Source database not configured"); + return false; + } + + try + { + _sourceExporter = new SqlServerExporter( + _config.Source, + _loggerFactory.CreateLogger<SqlServerExporter>()); + + if (!_sourceExporter.Connect()) + { + _logger.LogError("Failed to connect to source database"); + return false; + } + + var tables = _sourceExporter.DiscoverTables(); + _logger.LogInformation("\nDiscovered {Count} tables:", tables.Count); + + var patterns = _schemaMapper.DetectNamingPatterns(tables); + var suggestions = _schemaMapper.SuggestTableMappings(tables); + + _logger.LogInformation("\nTable Details:"); + _logger.LogInformation(Separator); + _logger.LogInformation("{Header1,-30} {Header2,10} {Header3,15} {Header4,15}", "Table Name", "Columns", "Rows", "Special Cols"); + _logger.LogInformation(Separator); + + foreach (var tableName in tables.OrderBy(t => t)) + { + var tableInfo = _sourceExporter.GetTableInfo(tableName); + var jsonColumns = _sourceExporter.IdentifyJsonColumns(tableName, 100); + + _logger.LogInformation("{TableName,-30} {ColumnCount,10} {RowCount,15:N0} {JsonColumnCount,15}", tableName, tableInfo.Columns.Count, tableInfo.RowCount, jsonColumns.Count); + + if (jsonColumns.Count > 0) + { + _logger.LogInformation(" → JSON columns: {JsonColumns}", string.Join(", ", jsonColumns)); + } + } + + _logger.LogInformation(Separator); + + _sourceExporter.Disconnect(); + return true; + } + catch (Exception ex) + { + _logger.LogError("Error during table discovery: {Message}", ex.Message); + return false; + } + } + + public bool ExportAllTables(TableFilter? tableFilter = null) + { + if (_config.Source == null) + { + _logger.LogError("Source database not configured"); + return false; + } + + try + { + _sourceExporter = new SqlServerExporter( + _config.Source, + _loggerFactory.CreateLogger<SqlServerExporter>()); + + if (!_sourceExporter.Connect()) + { + _logger.LogError("Failed to connect to source database"); + return false; + } + + var reporter = new ExportReporter(_loggerFactory.CreateLogger<ExportReporter>()); + var allTables = _sourceExporter.DiscoverTables(); + + TableFilter effectiveFilter = tableFilter != null + ?
new TableFilter( + tableFilter.GetIncludeTables(), + tableFilter.GetExcludeTables(), + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()) + : new TableFilter( + null, + null, + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()); + + var tablesToExport = effectiveFilter.FilterTableList(allTables); + + reporter.StartExport(); + _logger.LogInformation("Exporting {Count} tables to CSV\n", tablesToExport.Count); + + foreach (var tableName in tablesToExport) + { + reporter.StartTable(tableName); + + try + { + var (columns, data) = _sourceExporter.ExportTableData(tableName, _config.BatchSize); + var specialColumns = _sourceExporter.IdentifyJsonColumns(tableName); + var csvPath = _csvHandler.ExportTableToCsv(tableName, columns, data.ToList(), specialColumns); + + if (_csvHandler.ValidateExport(data.Count, csvPath)) + { + reporter.FinishTable(ExportStatus.Success, data.Count); + } + else + { + reporter.FinishTable(ExportStatus.Failed, 0, "Export validation failed"); + } + } + catch (Exception ex) + { + reporter.FinishTable(ExportStatus.Failed, 0, ex.Message); + } + } + + reporter.FinishExport(); + _sourceExporter.Disconnect(); + return reporter.GetSummaryStats().FailedTables == 0; + } + catch (Exception ex) + { + _logger.LogError("Error during export: {Message}", ex.Message); + return false; + } + } + + public bool ImportToDatabase( + string dbType, + bool createTables = false, + bool clearExisting = false, + TableFilter? tableFilter = null, + int? batchSize = null) + { + try + { + if (!_config.Destinations.TryGetValue(dbType, out var destConfig)) + { + _logger.LogError("Database type '{DbType}' not found in configuration", dbType); + return false; + } + + dynamic? importer = CreateImporter(dbType, destConfig); + if (importer == null) + { + _logger.LogError("Failed to create importer for {DbType}", dbType); + return false; + } + + if (!importer.Connect()) + { + _logger.LogError("Failed to connect to {DbType} database", dbType); + return false; + } + + var reporter = new ImportReporter(_loggerFactory.CreateLogger<ImportReporter>()); + reporter.StartImport(); + + importer.DisableForeignKeys(); + + var csvFiles = Directory.GetFiles(_config.CsvSettings.OutputDir, "*.csv"); + var tableNames = csvFiles.Select(f => Path.GetFileNameWithoutExtension(f)) + .OrderBy(name => name) + .ToList(); + + TableFilter effectiveFilter = tableFilter != null + ?
new TableFilter( + tableFilter.GetIncludeTables(), + tableFilter.GetExcludeTables(), + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()) + : new TableFilter( + null, + null, + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()); + + var tablesToImport = effectiveFilter.FilterTableList(tableNames); + _logger.LogInformation("\nImporting {Count} tables to {DbType}", tablesToImport.Count, dbType); + + foreach (var tableName in tablesToImport) + { + var csvPath = Path.Combine(_config.CsvSettings.OutputDir, $"{tableName}.csv"); + + if (!File.Exists(csvPath)) + { + _logger.LogWarning("CSV file not found for table {TableName}, skipping", tableName); + continue; + } + + try + { + var (columns, data) = _csvHandler.ImportCsvToData( + csvPath, + _schemaMapper.GetSpecialColumnsForTable(tableName)); + + var destTableName = _schemaMapper.GetDestinationTableName(tableName, dbType); + reporter.StartTable(tableName, destTableName, data.Count); + + var tableExists = importer.TableExists(destTableName); + + if (!tableExists && !createTables) + { + reporter.FinishTable(ImportStatus.Skipped, 0, + errorMessage: "Table does not exist and --create-tables not specified"); + continue; + } + + if (clearExisting && tableExists) + { + // Clearing is implemented as a drop, so the table must be recreated afterwards + if (!createTables) + { + reporter.FinishTable(ImportStatus.Skipped, 0, + errorMessage: "clearExisting drops the table, so createTables must also be set to recreate it"); + continue; + } + _logger.LogInformation("Dropping existing table {DestTableName} before re-import", destTableName); + importer.DropTable(destTableName); + tableExists = false; + } + + if (!tableExists && createTables) + { + var tableInfo = CreateBasicTableInfo(tableName, columns, data); + var specialColumns = _schemaMapper.GetSpecialColumnsForTable(tableName); + + if (!importer.CreateTableFromSchema( + destTableName, + tableInfo.Columns, + tableInfo.ColumnTypes, + specialColumns)) + { + reporter.FinishTable(ImportStatus.Failed, 0, + errorMessage: "Failed to create table"); + continue; + } + } + + var effectiveBatchSize = batchSize ?? _config.BatchSize; + var success = importer.ImportData(destTableName, columns, data, effectiveBatchSize); + + if (success) + { + var actualCount = importer.GetTableRowCount(destTableName); + reporter.FinishTable(ImportStatus.Success, actualCount); + } + else + { + reporter.FinishTable(ImportStatus.Failed, 0, + errorMessage: "Import operation failed"); + } + } + catch (Exception ex) + { + _logger.LogError("Error importing {TableName}: {Message}", tableName, ex.Message); + reporter.FinishTable(ImportStatus.Failed, 0, errorMessage: ex.Message); + } + } + + importer.EnableForeignKeys(); + reporter.FinishImport(); + + var logsDir = "logs"; + Directory.CreateDirectory(logsDir); + var reportPath = Path.Combine(logsDir, + $"import_report_{dbType}_{DateTime.Now:yyyyMMdd_HHmmss}.txt"); + reporter.ExportReport(reportPath); + + importer.Disconnect(); + + var summary = reporter.GetSummaryStats(); + return summary.FailedTables == 0; + } + catch (Exception ex) + { + _logger.LogError("Error during import: {Message}", ex.Message); + return false; + } + } + + public bool VerifyImport(string dbType, TableFilter? tableFilter = null) + { + try + { + if (!_config.Destinations.TryGetValue(dbType, out var destConfig)) + { + _logger.LogError("Database type '{DbType}' not found in configuration", dbType); + return false; + } + + dynamic?
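// Flag semantics for ImportToDatabase (illustrative call): createTables allows
// schema creation from the CSV-inferred column types, and clearExisting drops and
// recreates the destination table, so it requires createTables as enforced above.
//   recipe.ImportToDatabase("postgres", createTables: true, clearExisting: true);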
importer = CreateImporter(dbType, destConfig); + if (importer == null) + { + _logger.LogError("Failed to create importer for {DbType}", dbType); + return false; + } + + if (!importer.Connect()) + { + _logger.LogError("Failed to connect to {DbType} database", dbType); + return false; + } + + var reporter = new VerificationReporter(_loggerFactory.CreateLogger<VerificationReporter>()); + reporter.StartVerification(); + + var csvFiles = Directory.GetFiles(_config.CsvSettings.OutputDir, "*.csv"); + var tableNames = csvFiles.Select(f => Path.GetFileNameWithoutExtension(f)) + .OrderBy(name => name) + .ToList(); + + TableFilter effectiveFilter = tableFilter != null + ? new TableFilter( + tableFilter.GetIncludeTables(), + tableFilter.GetExcludeTables(), + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()) + : new TableFilter( + null, + null, + _config.ExcludeTables, + _loggerFactory.CreateLogger<TableFilter>()); + + var tablesToVerify = effectiveFilter.FilterTableList(tableNames); + _logger.LogInformation("\nVerifying {Count} tables in {DbType}", tablesToVerify.Count, dbType); + + foreach (var tableName in tablesToVerify) + { + var csvPath = Path.Combine(_config.CsvSettings.OutputDir, $"{tableName}.csv"); + + if (!File.Exists(csvPath)) + { + reporter.VerifyTable(tableName, tableName, -1, 0, + errorMessage: "CSV file not found"); + continue; + } + + try + { + var csvRowCount = CountCsvRows(csvPath); + var destTableName = _schemaMapper.GetDestinationTableName(tableName, dbType); + + if (!importer.TableExists(destTableName)) + { + reporter.VerifyTable(tableName, destTableName, csvRowCount, 0, + errorMessage: "Table does not exist in database"); + continue; + } + + var dbRowCount = importer.GetTableRowCount(destTableName); + reporter.VerifyTable(tableName, destTableName, csvRowCount, dbRowCount); + } + catch (Exception ex) + { + _logger.LogError("Error verifying {TableName}: {Message}", tableName, ex.Message); + reporter.VerifyTable(tableName, tableName, -1, 0, errorMessage: ex.Message); + } + } + + reporter.FinishVerification(); + + var logsDir = "logs"; + Directory.CreateDirectory(logsDir); + var reportPath = Path.Combine(logsDir, + $"verification_report_{dbType}_{DateTime.Now:yyyyMMdd_HHmmss}.txt"); + reporter.ExportReport(reportPath); + + importer.Disconnect(); + + var summary = reporter.GetSummaryStats(); + return summary.MismatchedTables == 0 && summary.ErrorTables == 0 && summary.MissingTables == 0; + } + catch (Exception ex) + { + _logger.LogError("Error during verification: {Message}", ex.Message); + return false; + } + } + + public bool TestConnection(string dbType) + { + try + { + if (!_config.Destinations.TryGetValue(dbType, out var destConfig)) + { + _logger.LogError("Database type '{DbType}' not found in configuration", dbType); + return false; + } + + dynamic? importer = CreateImporter(dbType, destConfig); + if (importer == null) + { + _logger.LogError("Failed to create importer for {DbType}", dbType); + return false; + } + + _logger.LogInformation("Testing connection to {DbType}...", dbType); + var result = importer.TestConnection(); + + if (result) + { + _logger.LogInformation("✓ Connection to {DbType} successful!", dbType); + } + else + { + _logger.LogError("✗ Connection to {DbType} failed", dbType); + } + + return result; + } + catch (Exception ex) + { + _logger.LogError("Connection test failed: {Message}", ex.Message); + return false; + } + } + + private dynamic?
CreateImporter(string dbType, DatabaseConfig config) => + dbType.ToLower() switch + { + "postgres" or "postgresql" => new PostgresImporter(config, _loggerFactory.CreateLogger<PostgresImporter>()), + "mariadb" or "mysql" => new MariaDbImporter(config, _loggerFactory.CreateLogger<MariaDbImporter>()), + "sqlite" => new SqliteImporter(config, _loggerFactory.CreateLogger<SqliteImporter>()), + "sqlserver" or "mssql" => new SqlServerImporter(config, _loggerFactory.CreateLogger<SqlServerImporter>()), + _ => null + }; + + private static TableInfo CreateBasicTableInfo(string tableName, List<string> columns, List<object?[]> data) + { + var columnTypes = new Dictionary<string, string>(); + + for (int i = 0; i < columns.Count; i++) + { + var columnName = columns[i]; + // Infer a type from the first row only; a null sample falls back to NVARCHAR(MAX) + var sampleValue = data.FirstOrDefault()?[i]; + + var inferredType = sampleValue switch + { + null => "NVARCHAR(MAX)", + int => "INT", + long => "BIGINT", + double or float or decimal => "DECIMAL(18,6)", + bool => "BIT", + DateTime => "DATETIME2", + byte[] => "VARBINARY(MAX)", + _ => "NVARCHAR(MAX)" + }; + + columnTypes[columnName] = inferredType + " NULL"; + } + + return new TableInfo + { + Name = tableName, + Columns = columns, + ColumnTypes = columnTypes, + RowCount = data.Count + }; + } + + private int CountCsvRows(string csvPath) + { + // Note: counts physical lines, so quoted fields containing newlines will skew the count + var lineCount = 0; + using (var reader = new StreamReader(csvPath)) + { + while (reader.ReadLine() != null) + { + lineCount++; + } + } + + if (_config.CsvSettings.IncludeHeaders) + { + lineCount--; + } + + return Math.Max(0, lineCount); + } +} diff --git a/util/Seeder/Seeder.csproj b/util/Seeder/Seeder.csproj index 4d7fbab767..0b3b39ff67 100644 --- a/util/Seeder/Seeder.csproj +++ b/util/Seeder/Seeder.csproj @@ -16,6 +16,15 @@ + + + + + + + + +
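// End-to-end usage sketch for CsvMigrationRecipe (illustrative; assumes a populated
// MigrationConfig `config` and an ILoggerFactory `loggerFactory` built elsewhere):
//   var recipe = new CsvMigrationRecipe(config, loggerFactory);
//   if (recipe.StartSshTunnel() && recipe.DiscoverAndAnalyzeTables())
//   {
//       var ok = recipe.ExportAllTables()
//           && recipe.ImportToDatabase("postgres", createTables: true)
//           && recipe.VerifyImport("postgres");
//       recipe.StopSshTunnel();
//   }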