using Bit.Seeder.Migration.Models;
using Bit.Seeder.Migration.Utils;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.Logging;
namespace Bit.Seeder.Migration.Databases;
///
/// SQLite database importer that handles schema creation and data import.
///
public class SqliteImporter(DatabaseConfig config, ILogger logger) : IDatabaseImporter
{
private readonly ILogger _logger = logger;
private readonly string _databasePath = config.Database;
private SqliteConnection? _connection;
private bool _disposed = false;
/// <summary>
/// Opens a connection to the SQLite database file at the configured path and applies the
/// importer's connection pragmas: foreign keys on, WAL journal mode, NORMAL synchronous.
/// The parent directory of the database file is created first if it does not exist.
/// </summary>
/// <returns>
/// <c>true</c> when the connection was opened and configured. <c>false</c> on any failure; the
/// error is logged and the importer is left disconnected (no connection is kept).
/// </returns>
/// <remarks>
/// Calling this while already connected first closes the existing connection via
/// <see cref="Disconnect"/>, so the previous connection is not leaked.
/// </remarks>
public bool Connect()
{
    try
    {
        if (_connection is not null)
        {
            // Re-connect: release the current connection instead of orphaning it.
            Disconnect();
        }

        // Ensure the database file's parent directory exists before opening.
        var directory = Path.GetDirectoryName(_databasePath);
        if (!string.IsNullOrEmpty(directory))
        {
            Directory.CreateDirectory(directory);
        }

        // Use the builder so a path containing ';', '=' or quotes is escaped correctly rather
        // than being parsed as additional connection-string keywords.
        var connectionString = new SqliteConnectionStringBuilder { DataSource = _databasePath }.ConnectionString;

        _connection = new SqliteConnection(connectionString);
        _connection.Open();

        // Enable foreign keys and set pragmas for better performance.
        string[] pragmas =
        [
            "PRAGMA foreign_keys = ON",
            "PRAGMA journal_mode = WAL",
            "PRAGMA synchronous = NORMAL",
        ];
        foreach (var pragma in pragmas)
        {
            using var command = new SqliteCommand(pragma, _connection);
            command.ExecuteNonQuery();
        }

        _logger.LogInformation("Connected to SQLite database: {DatabasePath}", _databasePath);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Failed to connect to SQLite: {Message}", ex.Message);

        // Do not leave a half-initialised connection behind. The other members treat a
        // non-null _connection as "connected".
        _connection?.Dispose();
        _connection = null;
        return false;
    }
}
/// <summary>
/// Closes and disposes the current SQLite connection, if there is one. Before closing, it runs
/// <c>PRAGMA wal_checkpoint(TRUNCATE)</c> so pending write-ahead-log work is completed.
/// Does nothing when not connected.
/// </summary>
/// <remarks>
/// A failing checkpoint is only logged as a warning; the connection is still closed afterwards.
/// </remarks>
public void Disconnect()
{
    if (_connection is null)
    {
        return;
    }

    const string CheckpointSql = "PRAGMA wal_checkpoint(TRUNCATE)";
    try
    {
        using var checkpoint = new SqliteCommand(CheckpointSql, _connection);
        checkpoint.ExecuteNonQuery();
    }
    catch (Exception ex)
    {
        // Best effort: closing the connection matters more than the checkpoint.
        _logger.LogWarning("Error during WAL checkpoint: {Message}", ex.Message);
    }

    _connection.Close();
    _connection.Dispose();
    _connection = null;
    _logger.LogInformation("Disconnected from SQLite");
}
/// <summary>
/// Issues <c>CREATE TABLE IF NOT EXISTS</c> for <paramref name="tableName"/>, with one column per
/// entry in <paramref name="columns"/>. Each column's SQL Server type is mapped to a SQLite type.
/// </summary>
/// <param name="tableName">Table to create; checked with <c>IdentifierValidator.ValidateOrThrow</c> before use.</param>
/// <param name="columns">Column names in table order; each is validated, then double-quoted in the DDL.</param>
/// <param name="columnTypes">SQL Server type per column name; columns without an entry default to <c>VARCHAR(MAX)</c>.</param>
/// <param name="specialColumns">Optional names flagged as "special" when passed to the type converter; <c>null</c> means none.</param>
/// <returns><c>true</c> if the statement executed; <c>false</c> if anything inside the try block failed (the error is logged).</returns>
/// <exception cref="InvalidOperationException">Thrown when there is no open connection.</exception>
/// <remarks>
/// Table-name validation runs outside the try block, so its exception reaches the caller.
/// Column-name validation runs inside it, so a bad column name is logged and reported as <c>false</c>.
/// </remarks>
public bool CreateTableFromSchema(
    string tableName,
    List<string> columns,
    Dictionary<string, string> columnTypes,
    List<string>? specialColumns = null)
{
    List<string> special = specialColumns ?? [];

    if (_connection is null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    // Identifiers are spliced into SQL text, so reject anything unsafe up front.
    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    try
    {
        var columnDefinitions = columns.ConvertAll(column =>
        {
            IdentifierValidator.ValidateOrThrow(column, "column name");
            var sourceType = columnTypes.GetValueOrDefault(column, "VARCHAR(MAX)");
            var targetType = ConvertSqlServerTypeToSQLite(sourceType, special.Contains(column));
            return $"\"{column}\" {targetType}";
        });

        var createSql = $@"
CREATE TABLE IF NOT EXISTS ""{tableName}"" (
{string.Join(",\n ", columnDefinitions)}
)";

        _logger.LogInformation("Creating table {TableName} in SQLite", tableName);
        _logger.LogDebug("CREATE TABLE SQL: {CreateSql}", createSql);

        using var ddl = new SqliteCommand(createSql, _connection);
        ddl.ExecuteNonQuery();

        _logger.LogInformation("Successfully created table {TableName}", tableName);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error creating table {TableName}: {Message}", tableName, ex.Message);
        return false;
    }
}
/// <summary>
/// Lists the column names of <paramref name="tableName"/>, in the row order returned by
/// <c>PRAGMA table_info</c>.
/// </summary>
/// <param name="tableName">Table to inspect; checked with <c>IdentifierValidator.ValidateOrThrow</c> before being quoted into the pragma.</param>
/// <returns>
/// The column names. Returns an empty list if the query throws (the error is logged), and
/// presumably also when the table does not exist, since the pragma then yields no rows.
/// </returns>
/// <exception cref="InvalidOperationException">Thrown when there is no open connection.</exception>
public List<string> GetTableColumns(string tableName)
{
    if (_connection is null)
    {
        throw new InvalidOperationException("Not connected to database");
    }

    // The name is spliced into the pragma text, so validate it first.
    IdentifierValidator.ValidateOrThrow(tableName, "table name");

    try
    {
        // table_info rows are (cid, name, type, notnull, dflt_value, pk); the name is column 1.
        const int NameOrdinal = 1;
        var names = new List<string>();

        var pragmaSql = $"PRAGMA table_info(\"{tableName}\")";
        using var pragma = new SqliteCommand(pragmaSql, _connection);
        using var rows = pragma.ExecuteReader();
        while (rows.Read())
        {
            names.Add(rows.GetString(NameOrdinal));
        }

        return names;
    }
    catch (Exception ex)
    {
        _logger.LogError("Error getting columns for table {TableName}: {Message}", tableName, ex.Message);
        return [];
    }
}
///
/// Imports data into a table using batch INSERT statements.
///
public bool ImportData(
string tableName,
List columns,
List