diff --git a/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs b/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs
index 81a94f0f7c..4c5d70340f 100644
--- a/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs
+++ b/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs
@@ -1,6 +1,7 @@
#nullable enable
using System.Data;
+using Bit.Core;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
@@ -8,6 +9,7 @@ using Bit.Core.Tools.Repositories;
using Bit.Infrastructure.Dapper.Repositories;
using Bit.Infrastructure.Dapper.Tools.Helpers;
using Dapper;
+using Microsoft.AspNetCore.DataProtection;
using Microsoft.Data.SqlClient;
namespace Bit.Infrastructure.Dapper.Tools.Repositories;
@@ -15,13 +17,24 @@ namespace Bit.Infrastructure.Dapper.Tools.Repositories;
///
public class SendRepository : Repository, ISendRepository
{
- public SendRepository(GlobalSettings globalSettings)
- : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString)
+ private readonly IDataProtector _dataProtector;
+
+ public SendRepository(GlobalSettings globalSettings, IDataProtectionProvider dataProtectionProvider)
+ : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString, dataProtectionProvider)
{ }
- public SendRepository(string connectionString, string readOnlyConnectionString)
+ public SendRepository(string connectionString, string readOnlyConnectionString, IDataProtectionProvider dataProtectionProvider)
: base(connectionString, readOnlyConnectionString)
- { }
+ {
+ _dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose);
+ }
+
+ public override async Task GetByIdAsync(Guid id)
+ {
+ var send = await base.GetByIdAsync(id);
+ UnprotectData(send);
+ return send;
+ }
///
public async Task> GetManyByUserIdAsync(Guid userId)
@@ -33,7 +46,9 @@ public class SendRepository : Repository, ISendRepository
new { UserId = userId },
commandType: CommandType.StoredProcedure);
- return results.ToList();
+ var sends = results.ToList();
+ UnprotectData(sends);
+ return sends;
}
}
@@ -47,15 +62,35 @@ public class SendRepository : Repository, ISendRepository
new { DeletionDate = deletionDateBefore },
commandType: CommandType.StoredProcedure);
- return results.ToList();
+ var sends = results.ToList();
+ UnprotectData(sends);
+ return sends;
}
}
+ public override async Task CreateAsync(Send send)
+ {
+ await ProtectDataAndSaveAsync(send, async () => await base.CreateAsync(send));
+ return send;
+ }
+
+ public override async Task ReplaceAsync(Send send)
+ {
+ await ProtectDataAndSaveAsync(send, async () => await base.ReplaceAsync(send));
+ }
+
///
public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid userId, IEnumerable sends)
{
return async (connection, transaction) =>
{
+ // Protect all sends before bulk update
+ var sendsList = sends.ToList();
+ foreach (var send in sendsList)
+ {
+ ProtectData(send);
+ }
+
// Create temp table
var sqlCreateTemp = @"
SELECT TOP 0 *
@@ -71,7 +106,7 @@ public class SendRepository : Repository, ISendRepository
using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction))
{
bulkCopy.DestinationTableName = "#TempSend";
- var sendsTable = sends.ToDataTable();
+ var sendsTable = sendsList.ToDataTable();
foreach (DataColumn col in sendsTable.Columns)
{
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
@@ -101,6 +136,69 @@ public class SendRepository : Repository, ISendRepository
cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId;
cmd.ExecuteNonQuery();
}
+
+ // Unprotect after save
+ foreach (var send in sendsList)
+ {
+ UnprotectData(send);
+ }
};
}
+
+ private async Task ProtectDataAndSaveAsync(Send send, Func saveTask)
+ {
+ if (send == null)
+ {
+ await saveTask();
+ return;
+ }
+
+ // Capture original value
+ var originalEmailHashes = send.EmailHashes;
+
+ // Protect value
+ ProtectData(send);
+
+ // Save
+ await saveTask();
+
+ // Restore original value
+ send.EmailHashes = originalEmailHashes;
+ }
+
+ private void ProtectData(Send send)
+ {
+ if (!send.EmailHashes?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false)
+ {
+ send.EmailHashes = string.Concat(Constants.DatabaseFieldProtectedPrefix,
+ _dataProtector.Protect(send.EmailHashes!));
+ }
+ }
+
+ private void UnprotectData(Send? send)
+ {
+ if (send == null)
+ {
+ return;
+ }
+
+ if (send.EmailHashes?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false)
+ {
+ send.EmailHashes = _dataProtector.Unprotect(
+ send.EmailHashes.Substring(Constants.DatabaseFieldProtectedPrefix.Length));
+ }
+ }
+
+ private void UnprotectData(IEnumerable sends)
+ {
+ if (sends == null)
+ {
+ return;
+ }
+
+ foreach (var send in sends)
+ {
+ UnprotectData(send);
+ }
+ }
}
diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
index a0ee0376c0..3f638f88e5 100644
--- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
+++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
@@ -119,6 +119,7 @@ public class DatabaseContext : DbContext
var eOrganizationDomain = builder.Entity();
var aWebAuthnCredential = builder.Entity();
var eOrganizationMemberBaseDetail = builder.Entity();
+ var eSend = builder.Entity();
// Shadow property configurations go here
@@ -148,6 +149,7 @@ public class DatabaseContext : DbContext
var dataProtectionConverter = new DataProtectionConverter(dataProtector);
eUser.Property(c => c.Key).HasConversion(dataProtectionConverter);
eUser.Property(c => c.MasterPassword).HasConversion(dataProtectionConverter);
+ eSend.Property(c => c.EmailHashes).HasConversion(dataProtectionConverter);
if (Database.IsNpgsql())
{