using Npgsql;
using NpgsqlTypes;
using Bit.Seeder.Migration.Models;
using Bit.Seeder.Migration.Utils;
using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// PostgreSQL database importer that handles schema creation, data import, and constraint management.
/// </summary>
public class PostgresImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter
{
/// <summary>Standard PostgreSQL port, used when the configuration does not supply a positive port.</summary>
private const int DefaultPostgresPort = 5432;

// Logger and connection settings captured from the primary-constructor arguments.
private readonly ILogger _logger = logger;
private readonly string _host = config.Host;
private readonly int _port = config.Port <= 0 ? DefaultPostgresPort : config.Port;
private readonly string _database = config.Database;
private readonly string _username = config.Username;
private readonly string _password = config.Password;

// The open connection, or null while not connected.
private NpgsqlConnection? _connection;

// Presumably set by the disposal logic (not visible in this excerpt); defaults to false.
private bool _disposed;
/// <summary>
/// Opens a connection to the PostgreSQL database described by the importer's configuration.
/// </summary>
/// <returns>
/// <c>true</c> when the connection was opened; <c>false</c> when anything failed. The failure is logged
/// and the importer is left without a connection.
/// </returns>
/// <remarks>
/// The connection string is assembled with <see cref="NpgsqlConnectionStringBuilder"/>. Values containing
/// separators such as <c>;</c> or <c>=</c> (typically passwords) are then quoted correctly instead of
/// corrupting the string or injecting extra keywords. Any connection left over from an earlier call is
/// disposed first so it is not leaked.
/// </remarks>
public bool Connect()
{
    NpgsqlConnection? connection = null;
    try
    {
        // Release a connection from a previous Connect() call before replacing it.
        _connection?.Dispose();
        _connection = null;

        var builder = new NpgsqlConnectionStringBuilder
        {
            Host = _host,
            Port = _port,
            Database = _database,
            Username = _username,
            Password = _password,
            Timeout = DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT,
            CommandTimeout = DbSeederConstants.DEFAULT_COMMAND_TIMEOUT
        };

        connection = new NpgsqlConnection(builder.ConnectionString);
        connection.Open();

        // Publish the connection only once it is actually open.
        _connection = connection;
        _logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database);
        return true;
    }
    catch (Exception ex)
    {
        // Do not leak the half-initialised connection or leave it looking "connected" to other methods.
        connection?.Dispose();
        _connection = null;
        _logger.LogError("Failed to connect to PostgreSQL: {Message}", ex.Message);
        return false;
    }
}
/// <summary>
/// Closes and disposes the current connection, if any, and logs the disconnect.
/// Does nothing when no connection is held.
/// </summary>
public void Disconnect()
{
    if (_connection is null)
    {
        return;
    }

    _connection.Close();
    _connection.Dispose();
    _connection = null;
    _logger.LogInformation("Disconnected from PostgreSQL");
}
/// <summary>
/// Issues <c>CREATE TABLE IF NOT EXISTS</c> for <paramref name="tableName"/>. Each column's SQL Server
/// type is translated to a PostgreSQL type.
/// </summary>
/// <param name="tableName">Target table; validated, then emitted as a quoted (case-preserving) identifier.</param>
/// <param name="columns">Column names, in table order; each is validated and quoted.</param>
/// <param name="columnTypes">SQL Server type per column; columns absent here default to <c>VARCHAR(MAX)</c>.</param>
/// <param name="specialColumns">Columns whose conversion is flagged as special; <c>null</c> means none.</param>
/// <returns><c>true</c> if the statement executed; <c>false</c> if building or executing it failed (logged).</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
/// <remarks>
/// An invalid table name is rejected before the <c>try</c> block, so the validator's exception reaches the
/// caller. An invalid column name is caught and reported as <c>false</c>.
/// </remarks>
public bool CreateTableFromSchema(
    string tableName,
    List<string> columns,
    Dictionary<string, string> columnTypes,
    List<string>? specialColumns = null)
{
    specialColumns ??= [];

    if (_connection is null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    try
    {
        var definitions = new List<string>(columns.Count);
        for (var i = 0; i < columns.Count; i++)
        {
            var column = columns[i];
            IdentifierValidator.ValidateOrThrow(column, "column name");

            var postgresType = ConvertSqlServerTypeToPostgreSQL(
                columnTypes.GetValueOrDefault(column, "VARCHAR(MAX)"),
                specialColumns.Contains(column));

            // Quoted so PostgreSQL keeps the source casing instead of folding to lower case.
            definitions.Add($"\"{column}\" {postgresType}");
        }

        var createSql = $@"
CREATE TABLE IF NOT EXISTS ""{tableName}"" (
{string.Join(",\n ", definitions)}
)";

        _logger.LogInformation("Creating table {TableName} in PostgreSQL", tableName);
        _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);

        using var createCommand = new NpgsqlCommand(createSql, _connection);
        createCommand.ExecuteNonQuery();

        _logger.LogInformation("Successfully created table {TableName}", tableName);
        return true;
    }
    catch (Exception error)
    {
        _logger.LogError("Error creating table {TableName}: {Message}", tableName, error.Message);
        return false;
    }
}
/// <summary>
/// Resolves the stored (case-preserved) name of a table in the <c>public</c> schema from a
/// case-insensitive lookup.
/// </summary>
/// <param name="tableName">Table name to look up; compared case-insensitively.</param>
/// <returns>
/// The stored table name, or <c>null</c> when no table matches. Also <c>null</c> when the lookup fails or
/// the stored name does not pass identifier validation; both failures are logged.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
/// <remarks>
/// PostgreSQL permits quoted table names that differ only by case (e.g. "Users" and "users"). When more
/// than one matches, an exact-case match is preferred and remaining ties are broken alphabetically, so the
/// result is deterministic rather than whichever row <c>LIMIT 1</c> happens to see first.
/// </remarks>
private string? GetActualTableName(string tableName)
{
    if (_connection == null)
        throw new InvalidOperationException("Not connected to database");
    try
    {
        var query = @"
            SELECT table_name
            FROM information_schema.tables
            WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
            ORDER BY CASE WHEN table_name = @tableName THEN 0 ELSE 1 END, table_name
            LIMIT 1";
        using var command = new NpgsqlCommand(query, _connection);
        command.Parameters.AddWithValue("tableName", tableName);
        using var reader = command.ExecuteReader();
        if (reader.Read())
        {
            var actualTableName = reader.GetString(0);
            // Validate table name immediately to prevent second-order SQL injection
            IdentifierValidator.ValidateOrThrow(actualTableName, "table name");
            return actualTableName;
        }
        return null;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting actual table name for {TableName}: {Message}", tableName, ex.Message);
        return null;
    }
}
/// <summary>
/// Lists the column names of a table in the <c>public</c> schema, in ordinal (definition) order.
/// </summary>
/// <param name="tableName">Table to inspect; matched case-insensitively.</param>
/// <returns>
/// The validated column names. The list is empty when no table matches or the query fails; a failure is
/// logged, and a column name that fails validation also yields an empty list.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
/// <remarks>
/// The case-insensitive match is first resolved to a single table: an exact-case match is preferred, then
/// the alphabetically first name. Without this, tables whose quoted names differ only by case (e.g. "Users"
/// and "users") would have their columns merged into one interleaved list.
/// </remarks>
public List<string> GetTableColumns(string tableName)
{
    if (_connection == null)
        throw new InvalidOperationException("Not connected to database");
    try
    {
        var query = @"
            SELECT c.column_name
            FROM information_schema.columns c
            WHERE c.table_schema = 'public'
              AND c.table_name = (
                  SELECT t.table_name
                  FROM information_schema.tables t
                  WHERE LOWER(t.table_name) = LOWER(@tableName) AND t.table_schema = 'public'
                  ORDER BY CASE WHEN t.table_name = @tableName THEN 0 ELSE 1 END, t.table_name
                  LIMIT 1)
            ORDER BY c.ordinal_position";
        using var command = new NpgsqlCommand(query, _connection);
        command.Parameters.AddWithValue("tableName", tableName);
        var columns = new List<string>();
        using var reader = command.ExecuteReader();
        while (reader.Read())
        {
            var colName = reader.GetString(0);
            // Validate column name immediately to prevent second-order SQL injection
            IdentifierValidator.ValidateOrThrow(colName, "column name");
            columns.Add(colName);
        }
        return columns;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
        return [];
    }
}
/// <summary>
/// Maps each column of a table in the <c>public</c> schema to its <c>information_schema</c> data type.
/// </summary>
/// <param name="tableName">Table to inspect; matched case-insensitively.</param>
/// <returns>
/// A dictionary keyed case-insensitively (<see cref="StringComparer.OrdinalIgnoreCase"/>) by column name.
/// It is empty when no table matches or the query fails; a failure is logged. The empty result on the error
/// path uses the same comparer, so callers always receive a consistently-keyed dictionary.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
/// <remarks>
/// The case-insensitive match is first resolved to a single table: an exact-case match is preferred, then
/// the alphabetically first name. Without this, tables whose quoted names differ only by case would have
/// their column types mixed together.
/// </remarks>
private Dictionary<string, string> GetTableColumnTypes(string tableName)
{
    if (_connection == null)
        throw new InvalidOperationException("Not connected to database");
    try
    {
        var columnTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        var query = @"
            SELECT c.column_name, c.data_type
            FROM information_schema.columns c
            WHERE c.table_schema = 'public'
              AND c.table_name = (
                  SELECT t.table_name
                  FROM information_schema.tables t
                  WHERE LOWER(t.table_name) = LOWER(@tableName) AND t.table_schema = 'public'
                  ORDER BY CASE WHEN t.table_name = @tableName THEN 0 ELSE 1 END, t.table_name
                  LIMIT 1)";
        using var command = new NpgsqlCommand(query, _connection);
        command.Parameters.AddWithValue("tableName", tableName);
        using var reader = command.ExecuteReader();
        while (reader.Read())
        {
            var colName = reader.GetString(0);
            // Validate column name immediately to prevent second-order SQL injection
            IdentifierValidator.ValidateOrThrow(colName, "column name");
            columnTypes[colName] = reader.GetString(1);
        }
        return columnTypes;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting column types for table {TableName}: {Message}", tableName, ex.Message);
        // Same comparer as the success path so lookups behave identically for callers.
        return new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
    }
}
///
/// Imports data into a table using batch INSERT statements.
///
public bool ImportData(
string tableName,
List columns,
List