1
0
mirror of https://github.com/bitwarden/server synced 2025-12-24 12:13:17 +00:00
Files
server/util/Seeder/Migration/Databases/PostgresImporter.cs
2025-11-10 16:46:44 -08:00

876 lines
32 KiB
C#

using Bit.Seeder.Migration.Models;
using Bit.Seeder.Migration.Utils;
using Microsoft.Extensions.Logging;
using Npgsql;
using NpgsqlTypes;
namespace Bit.Seeder.Migration.Databases;
/// <summary>
/// PostgreSQL database importer that handles schema creation, data import, and constraint management.
/// </summary>
public class PostgresImporter(DatabaseConfig config, ILogger<PostgresImporter> logger) : IDatabaseImporter
{
// Logger used for every connection, schema and import diagnostic in this class.
private readonly ILogger<PostgresImporter> _logger = logger;
// Connection settings copied from the injected DatabaseConfig.
private readonly string _host = config.Host;
// Falls back to 5432 when the configured port is zero or negative.
private readonly int _port = config.Port > 0 ? config.Port : 5432;
private readonly string _database = config.Database;
private readonly string _username = config.Username;
private readonly string _password = config.Password;
// Current connection; null before Connect() is called and after Disconnect().
private NpgsqlConnection? _connection;
// Set by Dispose(bool) so the cleanup body runs only once.
private bool _disposed = false;
/// <summary>
/// Opens a connection to the PostgreSQL database described by the injected configuration.
/// </summary>
/// <remarks>
/// A connection left over from an earlier call is disposed before the new one is opened, and a
/// connection that fails to open is disposed instead of being stored, so the field only ever
/// holds a connection that opened successfully.
/// </remarks>
/// <returns><c>true</c> when the connection was opened; <c>false</c> when any step failed (the error is logged).</returns>
public bool Connect()
{
    NpgsqlConnection? connection = null;
    try
    {
        // SECURITY: Use NpgsqlConnectionStringBuilder to safely construct connection string
        var builder = new NpgsqlConnectionStringBuilder
        {
            Host = _host,
            Port = _port,
            Database = _database,
            Username = _username,
            Password = _password,
            Timeout = DbSeederConstants.DEFAULT_CONNECTION_TIMEOUT,
            CommandTimeout = DbSeederConstants.DEFAULT_COMMAND_TIMEOUT
        };

        // Release a previous connection so repeated Connect() calls do not leak it.
        _connection?.Dispose();
        _connection = null;

        connection = new NpgsqlConnection(builder.ConnectionString);
        connection.Open();

        // Publish the connection only once it is known to be usable.
        _connection = connection;
        _logger.LogInformation("Connected to PostgreSQL: {Host}:{Port}/{Database}", _host, _port, _database);
        return true;
    }
    catch (Exception ex)
    {
        // A connection that never opened must not be kept: later null checks would treat it as live.
        connection?.Dispose();
        _logger.LogError(ex, "Failed to connect to PostgreSQL: {Message}", ex.Message);
        return false;
    }
}
/// <summary>
/// Closes and disposes the current connection, if any, and clears the connection field.
/// </summary>
public void Disconnect()
{
    if (_connection is not { } activeConnection)
    {
        return;
    }

    activeConnection.Close();
    activeConnection.Dispose();
    _connection = null;
    _logger.LogInformation("Disconnected from PostgreSQL");
}
/// <summary>
/// Creates a table (if it does not already exist) from a list of columns and their SQL Server types.
/// </summary>
/// <param name="tableName">Name of the table to create; validated as an identifier before use.</param>
/// <param name="columns">Column names, in the order they should appear in the table.</param>
/// <param name="columnTypes">SQL Server type per column; columns without an entry are treated as <c>VARCHAR(MAX)</c>.</param>
/// <param name="specialColumns">Columns flagged to the type converter as JSON columns; defaults to none.</param>
/// <returns><c>true</c> when the statement executed; <c>false</c> when building or executing it failed.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool CreateTableFromSchema(
    string tableName,
    List<string> columns,
    Dictionary<string, string> columnTypes,
    List<string>? specialColumns = null)
{
    specialColumns ??= [];

    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    try
    {
        // Translate each SQL Server column definition into its quoted PostgreSQL equivalent.
        var columnDefinitions = columns
            .Select(columnName =>
            {
                IdentifierValidator.ValidateOrThrow(columnName, "column name");
                var sourceType = columnTypes.GetValueOrDefault(columnName, "VARCHAR(MAX)");
                var targetType = ConvertSqlServerTypeToPostgreSQL(sourceType, specialColumns.Contains(columnName));
                return $"\"{columnName}\" {targetType}";
            })
            .ToList();

        // Identifiers are quoted so PostgreSQL keeps their original casing.
        var createSql = $@"
CREATE TABLE IF NOT EXISTS ""{tableName}"" (
{string.Join(",\n ", columnDefinitions)}
)";

        _logger.LogInformation("Creating table {TableName} in PostgreSQL", tableName);
        _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);

        using (var createCommand = new NpgsqlCommand(createSql, _connection))
        {
            createCommand.ExecuteNonQuery();
        }

        _logger.LogInformation("Successfully created table {TableName}", tableName);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message);
        return false;
    }
}
/// <summary>
/// Looks up the stored name of a table in the <c>public</c> schema using a case-insensitive match.
/// </summary>
/// <param name="tableName">Table name to look for, in any casing.</param>
/// <returns>The name as reported by <c>information_schema.tables</c>, or <c>null</c> when no table matches or the lookup fails.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
private string? GetActualTableName(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    try
    {
        const string lookupSql = @"
SELECT table_name
FROM information_schema.tables
WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
LIMIT 1";

        using var lookup = new NpgsqlCommand(lookupSql, _connection);
        lookup.Parameters.AddWithValue("tableName", tableName);

        using var rows = lookup.ExecuteReader();
        if (!rows.Read())
        {
            return null;
        }

        var resolvedName = rows.GetString(0);
        // The name is later interpolated into SQL, so validate it as soon as it is read
        // (guards against second-order SQL injection).
        IdentifierValidator.ValidateOrThrow(resolvedName, "table name");
        return resolvedName;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting actual table name for {TableName}: {Message}", tableName, ex.Message);
        return null;
    }
}
/// <summary>
/// Lists the column names of a table in the <c>public</c> schema, ordered by ordinal position.
/// </summary>
/// <param name="tableName">Table name, matched case-insensitively.</param>
/// <returns>The column names; an empty list when the table has no match or the query fails.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public List<string> GetTableColumns(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    try
    {
        const string columnsSql = @"
SELECT column_name
FROM information_schema.columns
WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
ORDER BY ordinal_position";

        using var columnQuery = new NpgsqlCommand(columnsSql, _connection);
        columnQuery.Parameters.AddWithValue("tableName", tableName);

        var names = new List<string>();
        using (var rows = columnQuery.ExecuteReader())
        {
            while (rows.Read())
            {
                var name = rows.GetString(0);
                // Validate before storing: these names are later interpolated into SQL
                // (guards against second-order SQL injection).
                IdentifierValidator.ValidateOrThrow(name, "column name");
                names.Add(name);
            }
        }

        return names;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
        return new List<string>();
    }
}
/// <summary>
/// Reads the <c>data_type</c> reported by <c>information_schema.columns</c> for every column of a table
/// in the <c>public</c> schema.
/// </summary>
/// <param name="tableName">Table name, matched case-insensitively.</param>
/// <returns>
/// A map from column name (case-insensitive keys) to data type; an empty dictionary when the query fails.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
private Dictionary<string, string> GetTableColumnTypes(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    try
    {
        const string typesSql = @"
SELECT column_name, data_type
FROM information_schema.columns
WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'";

        var typesByColumn = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);

        using var typeQuery = new NpgsqlCommand(typesSql, _connection);
        typeQuery.Parameters.AddWithValue("tableName", tableName);

        using var rows = typeQuery.ExecuteReader();
        while (rows.Read())
        {
            var name = rows.GetString(0);
            // Validate before use: column names end up interpolated into SQL elsewhere
            // (guards against second-order SQL injection).
            IdentifierValidator.ValidateOrThrow(name, "column name");
            typesByColumn[name] = rows.GetString(1);
        }

        return typesByColumn;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting column types for table {TableName}: {Message}", tableName, ex.Message);
        return new();
    }
}
/// <summary>
/// Imports data into a table using batch INSERT statements.
/// </summary>
/// <param name="tableName">Target table; resolved case-insensitively against the <c>public</c> schema.</param>
/// <param name="columns">Source column names, positionally aligned with each row of <paramref name="data"/>.</param>
/// <param name="data">Rows to insert. Source columns with no matching table column are skipped.</param>
/// <param name="batchSize">Number of rows committed per transaction; must be greater than zero.</param>
/// <returns>
/// <c>true</c> when every batch committed or there was nothing to import; <c>false</c> when the table or
/// its columns could not be resolved, a column type is not in the allowlist, or an insert failed.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is zero or negative and there is data to import.</exception>
public bool ImportData(
    string tableName,
    List<string> columns,
    List<object[]> data,
    int batchSize = DbSeederConstants.DEFAULT_BATCH_SIZE)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    if (data.Count == 0)
    {
        _logger.LogWarning("No data to import for table {TableName}", tableName);
        return true;
    }

    // A zero or negative batch size would never advance the batching loop below.
    if (batchSize <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be greater than zero.");
    }

    try
    {
        // Get the actual table name with correct casing
        var actualTableName = GetActualTableName(tableName);
        if (actualTableName == null)
        {
            _logger.LogError("Table {TableName} not found in database", tableName);
            return false;
        }

        IdentifierValidator.ValidateOrThrow(actualTableName, "actual table name");

        var actualColumns = GetTableColumns(tableName);
        if (actualColumns.Count == 0)
        {
            _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
            return false;
        }

        // Get column types from the database
        var columnTypes = GetTableColumnTypes(tableName);

        // Keep only source columns that exist in the table (case-insensitive), remembering
        // their position in the source rows and the table's own spelling and type.
        var validColumnIndices = new List<int>();
        var validColumns = new List<string>();
        var validColumnTypes = new List<string>();
        var actualColumnsLookup = actualColumns.ToDictionary(c => c, c => c, StringComparer.OrdinalIgnoreCase);

        for (int i = 0; i < columns.Count; i++)
        {
            if (actualColumnsLookup.TryGetValue(columns[i], out var actualColumnName))
            {
                validColumnIndices.Add(i);
                validColumns.Add(actualColumnName); // Use the actual column name from DB
                validColumnTypes.Add(columnTypes.GetValueOrDefault(actualColumnName, "text"));
            }
            else
            {
                _logger.LogDebug("Column '{Column}' from CSV not found in table {TableName}", columns[i], tableName);
            }
        }

        if (validColumns.Count == 0)
        {
            _logger.LogError("No valid columns found for table {TableName}", tableName);
            _logger.LogError("CSV columns: {Columns}", string.Join(", ", columns));
            _logger.LogError("Table columns: {Columns}", string.Join(", ", actualColumns));
            return false;
        }

        // Project every row onto the retained columns; rows shorter than expected are padded with null.
        var filteredData = data.Select(row =>
            validColumnIndices.Select(i => i < row.Length ? row[i] : null).ToArray()
        ).ToList();

        _logger.LogInformation("Importing {Count} rows into {TableName}", filteredData.Count, tableName);

        // Build INSERT statement with explicit type casts for all types
        var quotedColumns = validColumns.Select(col => $"\"{col}\"").ToList();
        var placeholders = validColumns
            .Select((col, idx) => GetInsertPlaceholder(idx + 1, validColumnTypes[idx], col))
            .ToList();
        var insertSql = $"INSERT INTO \"{actualTableName}\" ({string.Join(", ", quotedColumns)}) VALUES ({string.Join(", ", placeholders)})";

        var totalImported = 0;
        for (int offset = 0; offset < filteredData.Count; offset += batchSize)
        {
            var batch = filteredData.GetRange(offset, Math.Min(batchSize, filteredData.Count - offset));

            // One transaction per batch: a failing row rolls back only its own batch.
            using var transaction = _connection.BeginTransaction();
            try
            {
                foreach (var row in batch)
                {
                    using var command = new NpgsqlCommand(insertSql, _connection, transaction);
                    var preparedRow = PrepareRowForInsert(row, validColumns);
                    for (int p = 0; p < preparedRow.Length; p++)
                    {
                        // Positional parameters ($1, $2, ...) are bound in order.
                        command.Parameters.AddWithValue(preparedRow[p] ?? DBNull.Value);
                    }

                    command.ExecuteNonQuery();
                }

                transaction.Commit();
                totalImported += batch.Count;

                if (filteredData.Count > DbSeederConstants.LOGGING_THRESHOLD)
                {
                    _logger.LogDebug("Batch: {BatchCount} rows ({TotalImported}/{FilteredDataCount} total)", batch.Count, totalImported, filteredData.Count);
                }
            }
            catch
            {
                transaction.SafeRollback(_connection, _logger, tableName);
                throw;
            }
        }

        _logger.LogInformation("Successfully imported {TotalImported} rows into {TableName}", totalImported, tableName);
        return true;
    }
    catch (Exception ex)
    {
        // Passing the exception lets the logger record the stack trace and inner exceptions.
        _logger.LogError(ex, "Error importing data into {TableName}: {Message}", tableName, ex.Message);
        return false;
    }
}

/// <summary>
/// Builds the positional placeholder, with an explicit cast where one is needed, for one INSERT column.
/// </summary>
/// <param name="paramNum">1-based position of the parameter in the statement.</param>
/// <param name="colType">Column data type as reported by <c>information_schema.columns</c>.</param>
/// <param name="col">Column name, used only in the error message.</param>
/// <exception cref="InvalidOperationException">The column type is not in the allowlist.</exception>
private static string GetInsertPlaceholder(int paramNum, string colType, string col)
{
    // Cast to appropriate type if needed - PostgreSQL requires explicit casts for text to other types
    // SECURITY: Use explicit allowlist to prevent potential SQL injection through compromised schema
    return colType switch
    {
        // UUID types
        "uuid" => $"${paramNum}::uuid",
        // Timestamp types
        "timestamp without time zone" => $"${paramNum}::timestamp",
        "timestamp with time zone" => $"${paramNum}::timestamptz",
        "date" => $"${paramNum}::date",
        "time without time zone" => $"${paramNum}::time",
        "time with time zone" => $"${paramNum}::timetz",
        // Integer types
        "smallint" => $"${paramNum}::smallint",
        "integer" => $"${paramNum}::integer",
        "bigint" => $"${paramNum}::bigint",
        // Numeric types
        "numeric" => $"${paramNum}::numeric",
        "decimal" => $"${paramNum}::decimal",
        "real" => $"${paramNum}::real",
        "double precision" => $"${paramNum}::double precision",
        // Boolean type
        "boolean" => $"${paramNum}::boolean",
        // Text and string types (no cast needed)
        "text" or "character varying" or "varchar" or "character" or "char" => $"${paramNum}",
        // Binary types
        "bytea" => $"${paramNum}::bytea",
        // JSON types
        "json" => $"${paramNum}::json",
        "jsonb" => $"${paramNum}::jsonb",
        // XML: CreateTableFromSchema can create XML columns, so importing into them must be allowed
        "xml" => $"${paramNum}::xml",
        // Array types (common cases)
        "text[]" => $"${paramNum}::text[]",
        "integer[]" => $"${paramNum}::integer[]",
        "uuid[]" => $"${paramNum}::uuid[]",
        // Reject unknown types to prevent potential schema-based attacks
        _ => throw new InvalidOperationException($"Unsupported PostgreSQL column type '{colType}' for column '{col}'. Type must be explicitly allowed in the PostgresImporter allowlist.")
    };
}
/// <summary>
/// Reports whether a table with the given name (compared case-insensitively) exists in the
/// <c>public</c> schema.
/// </summary>
/// <param name="tableName">Table name to check; validated as an identifier first.</param>
/// <returns><c>true</c> when the table exists; <c>false</c> when it does not or the check fails.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool TableExists(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    try
    {
        const string existsSql = @"
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE LOWER(table_name) = LOWER(@tableName) AND table_schema = 'public'
)";

        using var existsCommand = new NpgsqlCommand(existsSql, _connection);
        existsCommand.Parameters.AddWithValue("tableName", tableName);

        var exists = existsCommand.GetScalarValue<bool>(false, _logger, $"table existence check for {tableName}");
        return exists;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error checking if table {TableName} exists: {Message}", tableName, ex.Message);
        return false;
    }
}
/// <summary>
/// Counts the rows of a table after resolving its stored name.
/// </summary>
/// <param name="tableName">Table name, in any casing.</param>
/// <returns>The row count; <c>0</c> when the table is not found or the query fails.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public int GetTableRowCount(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    try
    {
        if (GetActualTableName(tableName) is not { } resolvedTable)
        {
            _logger.LogError("Table {TableName} not found in database", tableName);
            return 0;
        }

        // resolvedTable has been validated by GetActualTableName, so it is safe to quote inline.
        using var countCommand = new NpgsqlCommand($"SELECT COUNT(*) FROM \"{resolvedTable}\"", _connection);
        var scalar = countCommand.ExecuteScalar();
        return Convert.ToInt32(scalar);
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting row count for {TableName}: {Message}", tableName, ex.Message);
        return 0;
    }
}
/// <summary>
/// Drops a table (with <c>CASCADE</c>) after resolving its stored name.
/// </summary>
/// <param name="tableName">Table name, in any casing.</param>
/// <returns>
/// <c>true</c> when the table was dropped or did not exist; <c>false</c> when the drop failed.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool DropTable(string tableName)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    try
    {
        if (GetActualTableName(tableName) is not { } resolvedTable)
        {
            // Nothing to drop counts as success.
            _logger.LogWarning("Table {TableName} not found, skipping drop", tableName);
            return true;
        }

        var dropSql = $"DROP TABLE IF EXISTS \"{resolvedTable}\" CASCADE";
        using (var dropCommand = new NpgsqlCommand(dropSql, _connection))
        {
            dropCommand.ExecuteNonQuery();
        }

        _logger.LogInformation("Dropped table {TableName}", tableName);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error dropping table {TableName}: {Message}", tableName, ex.Message);
        return false;
    }
}
/// <summary>
/// Switches the session to <c>session_replication_role = replica</c> so rows can be imported
/// without foreign key enforcement getting in the way.
/// </summary>
/// <returns><c>true</c> when the setting was applied; <c>false</c> when the statement failed.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool DisableForeignKeys()
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    const string replicaRoleSql = "SET session_replication_role = replica;";

    try
    {
        _logger.LogInformation("Disabling foreign key constraints");

        using (var setRole = new NpgsqlCommand(replicaRoleSql, _connection))
        {
            setRole.ExecuteNonQuery();
        }

        _logger.LogInformation("Foreign key constraints deferred");
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error disabling foreign key constraints: {Message}", ex.Message);
        return false;
    }
}
/// <summary>
/// Restores <c>session_replication_role</c> to its default, undoing <see cref="DisableForeignKeys"/>.
/// </summary>
/// <returns><c>true</c> when the setting was restored; <c>false</c> when the statement failed.</returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool EnableForeignKeys()
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    const string defaultRoleSql = "SET session_replication_role = DEFAULT;";

    try
    {
        _logger.LogInformation("Re-enabling foreign key constraints");

        using (var resetRole = new NpgsqlCommand(defaultRoleSql, _connection))
        {
            resetRole.ExecuteNonQuery();
        }

        _logger.LogInformation("Foreign key constraints re-enabled");
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error re-enabling foreign key constraints: {Message}", ex.Message);
        return false;
    }
}
/// <summary>
/// Translates a SQL Server column type, optionally followed by <c>NULL</c> or <c>NOT NULL</c>,
/// into the PostgreSQL column type used when creating the table.
/// </summary>
/// <param name="sqlServerType">SQL Server type text, e.g. <c>VARCHAR(50) NOT NULL</c>.</param>
/// <param name="isJsonColumn">When <c>true</c> the column is always created as <c>TEXT</c>.</param>
/// <returns>The PostgreSQL type, suffixed with <c> NOT NULL</c> when the source type was not nullable.</returns>
private string ConvertSqlServerTypeToPostgreSQL(string sqlServerType, bool isJsonColumn)
{
    var isNullable = !sqlServerType.Contains("NOT NULL", StringComparison.OrdinalIgnoreCase);
    var nullabilitySuffix = isNullable ? "" : " NOT NULL";

    if (isJsonColumn)
    {
        return "TEXT" + nullabilitySuffix;
    }

    // Strip " NOT NULL" before " NULL": doing it the other way round turns "INT NOT NULL"
    // into "INT NOT", which then matches no known type.
    // ToUpperInvariant keeps the match independent of the machine's culture.
    var baseType = sqlServerType
        .Replace(" NOT NULL", "", StringComparison.OrdinalIgnoreCase)
        .Replace(" NULL", "", StringComparison.OrdinalIgnoreCase)
        .Trim()
        .ToUpperInvariant();

    var pgType = baseType switch
    {
        // VARCHAR(n) is valid PostgreSQL as-is; only VARCHAR(MAX) needs mapping.
        var t when t.StartsWith("VARCHAR", StringComparison.Ordinal) =>
            t.Contains("MAX", StringComparison.Ordinal) ? "TEXT" : t,
        var t when t.StartsWith("NVARCHAR", StringComparison.Ordinal) => "TEXT",
        "INT" or "INTEGER" => "INTEGER",
        "BIGINT" => "BIGINT",
        "SMALLINT" => "SMALLINT",
        "TINYINT" => "SMALLINT",
        "BIT" => "BOOLEAN",
        // DECIMAL(p, s) carries over unchanged.
        var t when t.StartsWith("DECIMAL", StringComparison.Ordinal) => t,
        "FLOAT" => "DOUBLE PRECISION",
        "REAL" => "REAL",
        "DATETIME" or "DATETIME2" or "SMALLDATETIME" => "TIMESTAMP",
        "DATE" => "DATE",
        "TIME" => "TIME",
        "DATETIMEOFFSET" => "TIMESTAMPTZ",
        "UNIQUEIDENTIFIER" => "UUID",
        var t when t.StartsWith("VARBINARY", StringComparison.Ordinal) => "BYTEA",
        "XML" => "XML",
        // Anything unrecognised is stored as text.
        _ => "TEXT"
    };

    return pgType + nullabilitySuffix;
}
/// <summary>
/// Normalises one row of raw values before it is sent to PostgreSQL.
/// </summary>
/// <remarks>
/// <c>null</c>, <see cref="DBNull.Value"/> and zero-length strings become <see cref="DBNull.Value"/>;
/// the strings "true" and "false" (any casing) become booleans; every other value, including
/// whitespace-only strings, passes through unchanged.
/// </remarks>
/// <param name="row">Values of a single row.</param>
/// <param name="columns">Column names for the row; not consulted by this method.</param>
/// <returns>A new array of the same length holding the normalised values.</returns>
private object[] PrepareRowForInsert(object?[] row, List<string> columns)
{
    var prepared = new object[row.Length];
    for (var index = 0; index < row.Length; index++)
    {
        prepared[index] = Normalize(row[index]);
    }

    return prepared;

    static object Normalize(object? cell)
    {
        if (cell is null || ReferenceEquals(cell, DBNull.Value))
        {
            return DBNull.Value;
        }

        if (cell is not string text)
        {
            return cell;
        }

        // Only a truly empty string maps to NULL; whitespace is kept so JSON and
        // similar payloads survive untouched.
        if (text.Length == 0)
        {
            return DBNull.Value;
        }

        if (text.Equals("true", StringComparison.OrdinalIgnoreCase))
        {
            return true;
        }

        if (text.Equals("false", StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        return cell;
    }
}
/// <summary>
/// Indicates that this importer offers a bulk path (<see cref="ImportDataBulk"/>, built on PostgreSQL COPY).
/// </summary>
/// <returns>Always <c>true</c>.</returns>
public bool SupportsBulkCopy() => true;
/// <summary>
/// Imports rows into a table through a PostgreSQL binary <c>COPY ... FROM STDIN</c>.
/// </summary>
/// <param name="tableName">Target table; resolved case-insensitively against the <c>public</c> schema.</param>
/// <param name="columns">Source column names, positionally aligned with each row of <paramref name="data"/>.</param>
/// <param name="data">Rows to import. Source columns with no matching table column are skipped.</param>
/// <returns>
/// <c>true</c> when the copy completed or there was nothing to import; <c>false</c> when the table or
/// its columns could not be resolved or the copy failed.
/// </returns>
/// <exception cref="InvalidOperationException">No connection is open.</exception>
public bool ImportDataBulk(
    string tableName,
    List<string> columns,
    List<object[]> data)
{
    if (_connection == null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    if (data.Count == 0)
    {
        _logger.LogWarning("No data to import for table {TableName}", tableName);
        return true;
    }

    try
    {
        // Resolve the table name as stored in the database (correct casing).
        var resolvedTable = GetActualTableName(tableName);
        if (resolvedTable == null)
        {
            _logger.LogError("Table {TableName} not found in database", tableName);
            return false;
        }

        var tableColumns = GetTableColumns(tableName);
        if (tableColumns.Count == 0)
        {
            _logger.LogError("Could not retrieve columns for table {TableName}", tableName);
            return false;
        }

        var typesByColumn = GetTableColumnTypes(tableName);

        // Match source columns to table columns case-insensitively, keeping the table's spelling.
        var sourceIndices = new List<int>();
        var targetColumns = new List<string>();
        var targetTypes = new List<string>();
        var columnsByName = tableColumns.ToDictionary(c => c, c => c, StringComparer.OrdinalIgnoreCase);

        var sourceIndex = 0;
        foreach (var sourceColumn in columns)
        {
            if (columnsByName.TryGetValue(sourceColumn, out var matchedColumn))
            {
                sourceIndices.Add(sourceIndex);
                targetColumns.Add(matchedColumn);
                targetTypes.Add(typesByColumn.GetValueOrDefault(matchedColumn, "text"));
            }
            else
            {
                _logger.LogDebug("Column '{Column}' from CSV not found in table {TableName}", sourceColumn, tableName);
            }

            sourceIndex++;
        }

        if (targetColumns.Count == 0)
        {
            _logger.LogError("No valid columns found for table {TableName}", tableName);
            return false;
        }

        // Project each row onto the matched columns; positions past the end of a short row become null.
        var projectedRows = new List<object?[]>(data.Count);
        foreach (var sourceRow in data)
        {
            var projected = new object?[sourceIndices.Count];
            for (var slot = 0; slot < projected.Length; slot++)
            {
                var from = sourceIndices[slot];
                projected[slot] = from < sourceRow.Length ? sourceRow[from] : null;
            }

            projectedRows.Add(projected);
        }

        _logger.LogInformation("Bulk importing {Count} rows into {TableName} using PostgreSQL COPY", projectedRows.Count, tableName);

        // Binary COPY into the resolved table, listing only the matched columns.
        var columnList = string.Join(", ", targetColumns.Select(name => $"\"{name}\""));
        var copyCommand = $"COPY \"{resolvedTable}\" ({columnList}) FROM STDIN (FORMAT BINARY)";

        using var writer = _connection.BeginBinaryImport(copyCommand);
        foreach (var projectedRow in projectedRows)
        {
            writer.StartRow();

            var cells = PrepareRowForInsert(projectedRow, targetColumns);
            for (var column = 0; column < cells.Length; column++)
            {
                var cell = cells[column];
                if (cell == null || cell == DBNull.Value)
                {
                    writer.WriteNull();
                    continue;
                }

                // Pick the binary representation from the column's database type.
                WriteValueForCopy(writer, cell, targetTypes[column]);
            }
        }

        var rowsImported = writer.Complete();
        _logger.LogInformation("Successfully bulk imported {RowsImported} rows into {TableName}", rowsImported, tableName);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error during bulk import into {TableName}: {Message}", tableName, ex.Message);
        _logger.LogError("Stack trace: {StackTrace}", ex.StackTrace);
        if (ex.InnerException is { } inner)
        {
            _logger.LogError("Inner exception: {Message}", inner.Message);
        }

        return false;
    }
}
/// <summary>
/// Writes one non-null value to a binary COPY stream using the representation that matches the
/// target column's type.
/// </summary>
/// <remarks>
/// Numeric and date/time text is parsed with the invariant culture so the result does not depend
/// on the machine's regional settings. Date/time text without an offset is read as UTC, and text
/// with an offset is normalised to UTC instead of being shifted to local time.
/// </remarks>
/// <param name="writer">Open binary importer positioned inside a row.</param>
/// <param name="value">The value to write; callers handle null and DBNull themselves.</param>
/// <param name="columnType">Column data type as reported by <c>information_schema.columns</c>.</param>
private void WriteValueForCopy(Npgsql.NpgsqlBinaryImporter writer, object value, string columnType)
{
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    // Parses date/time text independently of culture and machine time zone; the result is in UTC.
    static bool TryParseUtc(string text, out DateTime utc) =>
        DateTime.TryParse(
            text,
            System.Globalization.CultureInfo.InvariantCulture,
            System.Globalization.DateTimeStyles.AssumeUniversal | System.Globalization.DateTimeStyles.AdjustToUniversal,
            out utc);

    // Handle type-specific writing for PostgreSQL COPY
    switch (columnType.ToLowerInvariant())
    {
        case "uuid":
            if (value is Guid g)
                writer.Write(g, NpgsqlDbType.Uuid);
            else if (value is string strGuid && Guid.TryParse(strGuid, out var guid))
                writer.Write(guid, NpgsqlDbType.Uuid);
            else
                writer.Write(value.ToString()!, NpgsqlDbType.Uuid);
            break;
        case "boolean":
            if (value is bool b)
                writer.Write(b);
            else if (value is string strBool)
                writer.Write(strBool.Equals("true", StringComparison.OrdinalIgnoreCase) || strBool == "1");
            else
                writer.Write(Convert.ToBoolean(value, invariant));
            break;
        case "smallint":
            writer.Write(Convert.ToInt16(value, invariant));
            break;
        case "integer":
            writer.Write(Convert.ToInt32(value, invariant));
            break;
        case "bigint":
            writer.Write(Convert.ToInt64(value, invariant));
            break;
        case "real":
            writer.Write(Convert.ToSingle(value, invariant));
            break;
        case "double precision":
            writer.Write(Convert.ToDouble(value, invariant));
            break;
        case "numeric":
        case "decimal":
            writer.Write(Convert.ToDecimal(value, invariant));
            break;
        case "timestamp without time zone":
        case "timestamp":
            // The column stores the wall-clock value as-is, so the Kind is cleared instead of being
            // set to UTC: Npgsql 6+ refuses UTC DateTimes for this type unless legacy timestamp
            // behaviour is switched on.
            if (value is DateTime dt)
                writer.Write(DateTime.SpecifyKind(dt, DateTimeKind.Unspecified), NpgsqlDbType.Timestamp);
            else if (value is string strDt && TryParseUtc(strDt, out var parsedDt))
                writer.Write(DateTime.SpecifyKind(parsedDt, DateTimeKind.Unspecified), NpgsqlDbType.Timestamp);
            else
                writer.Write(value.ToString()!);
            break;
        case "timestamp with time zone":
        case "timestamptz":
            if (value is DateTime dtz)
            {
                // PostgreSQL timestamptz requires UTC DateTimes
                var utcValue = dtz.Kind switch
                {
                    DateTimeKind.Unspecified => DateTime.SpecifyKind(dtz, DateTimeKind.Utc),
                    DateTimeKind.Local => dtz.ToUniversalTime(),
                    _ => dtz
                };
                writer.Write(utcValue, NpgsqlDbType.TimestampTz);
            }
            else if (value is string strDtz && TryParseUtc(strDtz, out var parsedDtz))
            {
                // TryParseUtc already yields a UTC DateTime.
                writer.Write(parsedDtz, NpgsqlDbType.TimestampTz);
            }
            else
                writer.Write(value.ToString()!);
            break;
        case "date":
            if (value is DateTime date)
                writer.Write(date, NpgsqlDbType.Date);
            else if (value is string strDate && TryParseUtc(strDate, out var parsedDate))
                writer.Write(DateTime.SpecifyKind(parsedDate, DateTimeKind.Unspecified), NpgsqlDbType.Date);
            else
                writer.Write(value.ToString()!);
            break;
        case "bytea":
            if (value is byte[] bytes)
                writer.Write(bytes);
            else
                writer.Write(value.ToString()!);
            break;
        default:
            // Text and all other types; invariant formatting keeps numbers and dates culture-neutral.
            writer.Write(Convert.ToString(value, invariant) ?? string.Empty);
            break;
    }
}
/// <summary>
/// Tests the connection to PostgreSQL by connecting, executing <c>SELECT 1</c> and disconnecting again.
/// </summary>
/// <returns>
/// <c>true</c> when the connection opened and the probe returned 1; <c>false</c> otherwise (failures are logged).
/// </returns>
public bool TestConnection()
{
    try
    {
        if (!Connect())
        {
            return false;
        }

        try
        {
            using var command = new NpgsqlCommand("SELECT 1", _connection);
            var result = command.GetScalarValue<int>(0, _logger, "connection test");
            return result == 1;
        }
        finally
        {
            // Always release the probe connection, even when the query throws.
            Disconnect();
        }
    }
    catch (Exception ex)
    {
        _logger.LogError("PostgreSQL connection test failed: {Message}", ex.Message);
        return false;
    }
}
/// <summary>
/// Releases the importer's resources and tells the garbage collector that finalization is not needed.
/// </summary>
public void Dispose()
{
    Dispose(disposing: true);
    GC.SuppressFinalize(this);
}
/// <summary>
/// Core of the dispose pattern: disconnects on an explicit dispose and marks the instance as disposed.
/// </summary>
/// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>; the connection is closed only in that case.</param>
protected virtual void Dispose(bool disposing)
{
    if (_disposed)
    {
        return;
    }

    if (disposing)
    {
        Disconnect();
    }

    _disposed = true;
}
}