From 1dade9d4b868fb73907c0d280fd19bb0191692cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:57:53 +0100 Subject: [PATCH] [PM-24233] Use BulkResourceCreationService in CipherRepository (#6201) * Add constant for CipherRepositoryBulkResourceCreation in FeatureFlagKeys * Add bulk creation methods for Ciphers, Folders, and CollectionCiphers in BulkResourceCreationService - Implemented CreateCiphersAsync, CreateFoldersAsync, CreateCollectionCiphersAsync, and CreateTempCiphersAsync methods for bulk insertion. - Added helper methods to build DataTables for Ciphers, Folders, and CollectionCiphers. - Enhanced error handling for empty collections during bulk operations. * Refactor CipherRepository to utilize BulkResourceCreationService - Introduced IFeatureService to manage feature flag checks for bulk operations. - Updated methods to conditionally use BulkResourceCreationService for creating Ciphers, Folders, and CollectionCiphers based on feature flag status. - Enhanced existing bulk copy logic to maintain functionality while integrating feature flag checks. * Add InlineFeatureService to DatabaseDataAttribute for feature flag management - Introduced EnabledFeatureFlags property to DatabaseDataAttribute for configuring feature flags. - Integrated InlineFeatureService to provide feature flag checks within the service collection. - Enhanced GetData method to utilize feature flags for conditional service registration. * Add tests for bulk creation of Ciphers in CipherRepositoryTests - Implemented tests for bulk creation of Ciphers, Folders, and Collections with feature flag checks. - Added test cases for updating multiple Ciphers to validate bulk update functionality. - Enhanced existing test structure to ensure comprehensive coverage of bulk operations in the CipherRepository. * Refactor BulkResourceCreationService to use dynamic types for DataColumns - Updated DataColumn definitions in BulkResourceCreationService to utilize the actual types of properties from the cipher object instead of hardcoded types. - Simplified the assignment of nullable properties to directly use their values, improving code readability and maintainability. * Update BulkResourceCreationService to use specific types for DataColumns - Changed DataColumn definitions to use specific types (short and string) instead of dynamic types based on cipher properties. - Improved handling of nullable properties when assigning values to DataTable rows, ensuring proper handling of DBNull for null values. * Refactor CipherRepositoryTests for improved clarity and consistency - Renamed test methods to better reflect their purpose and improve readability. - Updated test data to use more descriptive names for users, folders, and collections. - Enhanced test structure with clear Arrange, Act, and Assert sections for better understanding of test flow. - Ensured all tests validate the expected outcomes for bulk operations with feature flag checks. * Update CipherRepositoryBulkResourceCreation feature flag key * Refactor DatabaseDataAttribute usage in CipherRepositoryTests to use array syntax for EnabledFeatureFlags * Update CipherRepositoryTests to use GenerateComb for generating unique IDs * Refactor CipherRepository methods to accept a boolean parameter for enabling bulk resource creation based on feature flags. Update tests to verify functionality with and without the feature flag enabled. * Refactor CipherRepository and related services to support new methods for bulk resource creation without boolean parameters. --- src/Core/Constants.cs | 1 + .../RotateUserAccountkeysCommand.cs | 15 +- .../ImportFeatures/ImportCiphersCommand.cs | 20 +- .../Vault/Repositories/ICipherRepository.cs | 22 ++ .../Services/Implementations/CipherService.cs | 10 +- .../Helpers/BulkResourceCreationService.cs | 190 ++++++++++++++++ .../Vault/Repositories/CipherRepository.cs | 211 ++++++++++++++++++ .../Vault/Repositories/CipherRepository.cs | 41 ++++ .../ImportCiphersAsyncCommandTests.cs | 136 ++++++++++- .../Vault/Services/CipherServiceTests.cs | 53 +++++ .../Repositories/CipherRepositoryTests.cs | 157 +++++++++++++ 11 files changed, 849 insertions(+), 7 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 393ab15e4c..2993f6a094 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -114,6 +114,7 @@ public static class FeatureFlagKeys public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions"; public const string CreateDefaultLocation = "pm-19467-create-default-location"; public const string DirectoryConnectorPreventUserRemoval = "pm-24592-directory-connector-prevent-user-removal"; + public const string CipherRepositoryBulkResourceCreation = "pm-24951-cipher-repository-bulk-resource-creation-service"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index 6967c9bf85..011fc2932f 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -25,6 +25,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; private readonly IPasswordHasher _passwordHasher; + private readonly IFeatureService _featureService; /// /// Instantiates a new @@ -45,7 +46,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, - IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository) + IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository, + IFeatureService featureService) { _userService = userService; _userRepository = userRepository; @@ -59,6 +61,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _identityErrorDescriber = errors; _credentialRepository = credentialRepository; _passwordHasher = passwordHasher; + _featureService = featureService; } /// @@ -100,7 +103,15 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand List saveEncryptedDataActions = new(); if (model.Ciphers.Any()) { - saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation_vNext(user.Id, model.Ciphers)); + } + else + { + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); + } } if (model.Folders.Any()) diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index c7f7e3aff7..ce269bc68c 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -108,7 +108,15 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.CreateAsync_vNext(importingUserId, ciphers, newFolders); + } + else + { + await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); + } // push await _pushService.PushSyncVaultAsync(importingUserId); @@ -183,7 +191,15 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.CreateAsync_vNext(ciphers, newCollections, collectionCiphers, newCollectionUsers); + } + else + { + await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); + } // push await _pushService.PushSyncVaultAsync(importingUserId); diff --git a/src/Core/Vault/Repositories/ICipherRepository.cs b/src/Core/Vault/Repositories/ICipherRepository.cs index 5a04a6651d..60b6e21f1d 100644 --- a/src/Core/Vault/Repositories/ICipherRepository.cs +++ b/src/Core/Vault/Repositories/ICipherRepository.cs @@ -32,12 +32,28 @@ public interface ICipherRepository : IRepository Task DeleteByUserIdAsync(Guid userId); Task DeleteByOrganizationIdAsync(Guid organizationId); Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); + /// + /// + /// This version uses the bulk resource creation service to create the temp table. + /// + Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers); /// /// Create ciphers and folders for the specified UserId. Must not be used to create organization owned items. /// Task CreateAsync(Guid userId, IEnumerable ciphers, IEnumerable folders); + /// + /// + /// This version uses the bulk resource creation service to create the temp tables. + /// + Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders); Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, IEnumerable collectionUsers); + /// + /// + /// This version uses the bulk resource creation service to create the temp tables. + /// + Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers, IEnumerable collectionUsers); Task SoftDeleteAsync(IEnumerable ids, Guid userId); Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); Task RestoreAsync(IEnumerable ids, Guid userId); @@ -68,4 +84,10 @@ public interface ICipherRepository : IRepository /// A list of ciphers with updated data UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid userId, IEnumerable ciphers); + /// + /// + /// This version uses the bulk resource creation service to create the temp table. + /// + UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext(Guid userId, + IEnumerable ciphers); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index 51ed4b0ce7..2a4cc6c137 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -642,7 +642,15 @@ public class CipherService : ICipherService cipherIds.Add(cipher.Id); } - await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); + if (useBulkResourceCreationService) + { + await _cipherRepository.UpdateCiphersAsync_vNext(sharingUserId, cipherInfos.Select(c => c.cipher)); + } + else + { + await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + } await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, organizationId, collectionIds); diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs index 139960ceba..3610c1c484 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -1,5 +1,6 @@ using System.Data; using Bit.Core.Entities; +using Bit.Core.Vault.Entities; using Microsoft.Data.SqlClient; namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers; @@ -15,6 +16,38 @@ public static class BulkResourceCreationService await bulkCopy.WriteToServerAsync(dataTable); } + public static async Task CreateCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable ciphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateFoldersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable folders, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[Folder]"; + var dataTable = BuildFoldersTable(bulkCopy, folders, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateCollectionCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collectionCiphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; + var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + public static async Task CreateTempCiphersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable ciphers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + private static DataTable BuildCollectionsUsersTable(SqlBulkCopy bulkCopy, IEnumerable collectionUsers, string errorMessage) { var collectionUser = collectionUsers.FirstOrDefault(); @@ -126,4 +159,161 @@ public static class BulkResourceCreationService return collectionsTable; } + + private static DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers, string errorMessage) + { + var c = ciphers.FirstOrDefault(); + + if (c == null) + { + throw new ApplicationException(errorMessage); + } + + var ciphersTable = new DataTable("CipherDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + ciphersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); + ciphersTable.Columns.Add(userIdColumn); + var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); + ciphersTable.Columns.Add(organizationId); + var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); + ciphersTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); + ciphersTable.Columns.Add(dataColumn); + var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); + ciphersTable.Columns.Add(favoritesColumn); + var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); + ciphersTable.Columns.Add(foldersColumn); + var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); + ciphersTable.Columns.Add(attachmentsColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + ciphersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + ciphersTable.Columns.Add(revisionDateColumn); + var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); + ciphersTable.Columns.Add(deletedDateColumn); + var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); + ciphersTable.Columns.Add(repromptColumn); + var keyColummn = new DataColumn(nameof(c.Key), typeof(string)); + ciphersTable.Columns.Add(keyColummn); + + foreach (DataColumn col in ciphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + ciphersTable.PrimaryKey = keys; + + foreach (var cipher in ciphers) + { + var row = ciphersTable.NewRow(); + + row[idColumn] = cipher.Id; + row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; + row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)cipher.Type; + row[dataColumn] = cipher.Data; + row[favoritesColumn] = cipher.Favorites; + row[foldersColumn] = cipher.Folders; + row[attachmentsColumn] = cipher.Attachments; + row[creationDateColumn] = cipher.CreationDate; + row[revisionDateColumn] = cipher.RevisionDate; + row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[repromptColumn] = cipher.Reprompt.HasValue ? cipher.Reprompt.Value : DBNull.Value; + row[keyColummn] = cipher.Key; + + ciphersTable.Rows.Add(row); + } + + return ciphersTable; + } + + private static DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders, string errorMessage) + { + var f = folders.FirstOrDefault(); + + if (f == null) + { + throw new ApplicationException(errorMessage); + } + + var foldersTable = new DataTable("FolderDataTable"); + + var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); + foldersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); + foldersTable.Columns.Add(userIdColumn); + var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); + foldersTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); + foldersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); + foldersTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in foldersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + foldersTable.PrimaryKey = keys; + + foreach (var folder in folders) + { + var row = foldersTable.NewRow(); + + row[idColumn] = folder.Id; + row[userIdColumn] = folder.UserId; + row[nameColumn] = folder.Name; + row[creationDateColumn] = folder.CreationDate; + row[revisionDateColumn] = folder.RevisionDate; + + foldersTable.Rows.Add(row); + } + + return foldersTable; + } + + private static DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers, string errorMessage) + { + var cc = collectionCiphers.FirstOrDefault(); + + if (cc == null) + { + throw new ApplicationException(errorMessage); + } + + var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); + + var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); + collectionCiphersTable.Columns.Add(collectionIdColumn); + var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); + collectionCiphersTable.Columns.Add(cipherIdColumn); + + foreach (DataColumn col in collectionCiphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = cipherIdColumn; + collectionCiphersTable.PrimaryKey = keys; + + foreach (var collectionCipher in collectionCiphers) + { + var row = collectionCiphersTable.NewRow(); + + row[collectionIdColumn] = collectionCipher.CollectionId; + row[cipherIdColumn] = collectionCipher.CipherId; + + collectionCiphersTable.Rows.Add(row); + } + + return collectionCiphersTable; + } } diff --git a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs index 180a90fd41..8c1f04affc 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs @@ -10,6 +10,7 @@ using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; +using Bit.Infrastructure.Dapper.AdminConsole.Helpers; using Bit.Infrastructure.Dapper.Repositories; using Bit.Infrastructure.Dapper.Vault.Helpers; using Dapper; @@ -408,6 +409,52 @@ public class CipherRepository : Repository, ICipherRepository }; } + /// + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( + Guid userId, IEnumerable ciphers) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + // Create temp table + var sqlCreateTemp = @" + SELECT TOP 0 * + INTO #TempCipher + FROM [dbo].[Cipher]"; + + await using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // Bulk copy data into temp table + await BulkResourceCreationService.CreateTempCiphersAsync(connection, transaction, ciphers); + + // Update cipher table from temp table + var sql = @" + UPDATE + [dbo].[Cipher] + SET + [Data] = TC.[Data], + [Attachments] = TC.[Attachments], + [RevisionDate] = TC.[RevisionDate], + [Key] = TC.[Key] + FROM + [dbo].[Cipher] C + INNER JOIN + #TempCipher TC ON C.Id = TC.Id + WHERE + C.[UserId] = @UserId + + DROP TABLE #TempCipher"; + + await using (var cmd = new SqlCommand(sql, connection, transaction)) + { + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); + } + }; + } + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) { if (!ciphers.Any()) @@ -490,6 +537,83 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + // 1. Create temp tables to bulk copy into. + + var sqlCreateTemp = @" + SELECT TOP 0 * + INTO #TempCipher + FROM [dbo].[Cipher]"; + + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // 2. Bulk copy into temp tables. + await BulkResourceCreationService.CreateTempCiphersAsync(connection, transaction, ciphers); + + // 3. Insert into real tables from temp tables and clean up. + + // Intentionally not including Favorites, Folders, and CreationDate + // since those are not meant to be bulk updated at this time + var sql = @" + UPDATE + [dbo].[Cipher] + SET + [UserId] = TC.[UserId], + [OrganizationId] = TC.[OrganizationId], + [Type] = TC.[Type], + [Data] = TC.[Data], + [Attachments] = TC.[Attachments], + [RevisionDate] = TC.[RevisionDate], + [DeletedDate] = TC.[DeletedDate], + [Key] = TC.[Key] + FROM + [dbo].[Cipher] C + INNER JOIN + #TempCipher TC ON C.Id = TC.Id + WHERE + C.[UserId] = @UserId + + DROP TABLE #TempCipher"; + + using (var cmd = new SqlCommand(sql, connection, transaction)) + { + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = userId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task CreateAsync(Guid userId, IEnumerable ciphers, IEnumerable folders) { if (!ciphers.Any()) @@ -538,6 +662,44 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + if (folders.Any()) + { + await BulkResourceCreationService.CreateFoldersAsync(connection, transaction, folders); + } + + await BulkResourceCreationService.CreateCiphersAsync(connection, transaction, ciphers); + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = userId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, IEnumerable collectionUsers) { @@ -607,6 +769,55 @@ public class CipherRepository : Repository, ICipherRepository } } + public async Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers, IEnumerable collectionUsers) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + await BulkResourceCreationService.CreateCiphersAsync(connection, transaction, ciphers); + + if (collections.Any()) + { + await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); + } + + if (collectionCiphers.Any()) + { + await BulkResourceCreationService.CreateCollectionCiphersAsync(connection, transaction, collectionCiphers); + } + + if (collectionUsers.Any()) + { + await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", + new { OrganizationId = ciphers.First().OrganizationId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) { using (var connection = new SqlConnection(ConnectionString)) diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index 3fae537a1e..d595fe7cfe 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -167,6 +167,16 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular create method. + /// + public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, + IEnumerable folders) + { + await CreateAsync(userId, ciphers, folders); + } + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, @@ -205,6 +215,18 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular create method. + /// + public async Task CreateAsync_vNext(IEnumerable ciphers, + IEnumerable collections, + IEnumerable collectionCiphers, + IEnumerable collectionUsers) + { + await CreateAsync(ciphers, collections, collectionCiphers, collectionUsers); + } + public async Task DeleteAsync(IEnumerable ids, Guid userId) { await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); @@ -907,6 +929,15 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular update method. + /// + public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) + { + await UpdateCiphersAsync(userId, ciphers); + } + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -970,6 +1001,16 @@ public class CipherRepository : Repository + /// + /// EF does not use the bulk resource creation service, so we need to use the regular update method. + /// + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( + Guid userId, IEnumerable ciphers) + { + return UpdateForKeyRotation(userId, ciphers); + } + public async Task UpsertAsync(CipherDetails cipher) { if (cipher.Id.Equals(default)) diff --git a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs index 0cb0deaf52..11f637d207 100644 --- a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs +++ b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs @@ -47,7 +47,41 @@ public class ImportCiphersAsyncCommandTests await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); // Assert - await sutProvider.GetDependency().Received(1).CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_WithBulkResourceCreationServiceEnabled_Success( + Guid importingUserId, + List ciphers, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .AnyPoliciesApplicableToUserAsync(importingUserId, PolicyType.OrganizationDataOwnership) + .Returns(false); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List { new Folder { UserId = importingUserId } }; + + var folderRelationships = new List>(); + + // Act + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateAsync_vNext(importingUserId, ciphers, Arg.Any>()); await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } @@ -77,7 +111,45 @@ public class ImportCiphersAsyncCommandTests await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); - await sutProvider.GetDependency().Received(1).CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, ciphers, Arg.Any>()); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_WithBulkResourceCreationServiceEnabled_WithPolicyRequirementsEnabled_WithOrganizationDataOwnershipPolicyDisabled_Success( + Guid importingUserId, + List ciphers, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PolicyRequirements) + .Returns(true); + + sutProvider.GetDependency() + .GetAsync(importingUserId) + .Returns(new OrganizationDataOwnershipPolicyRequirement( + OrganizationDataOwnershipState.Disabled, + [])); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List { new Folder { UserId = importingUserId } }; + + var folderRelationships = new List>(); + + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync_vNext(importingUserId, ciphers, Arg.Any>()); await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } @@ -187,6 +259,66 @@ public class ImportCiphersAsyncCommandTests await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } + [Theory, BitAutoData] + public async Task ImportIntoOrganizationalVaultAsync_WithBulkResourceCreationServiceEnabled_Success( + Organization organization, + Guid importingUserId, + OrganizationUser importingOrganizationUser, + List collections, + List ciphers, + SutProvider sutProvider) + { + organization.MaxCollections = null; + importingOrganizationUser.OrganizationId = organization.Id; + + foreach (var collection in collections) + { + collection.OrganizationId = organization.Id; + } + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organization.Id; + } + + KeyValuePair[] collectionRelationships = { + new(0, 0), + new(1, 1), + new(2, 2) + }; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, importingUserId) + .Returns(importingOrganizationUser); + + // Set up a collection that already exists in the organization + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns(new List { collections[0] }); + + await sutProvider.Sut.ImportIntoOrganizationalVaultAsync(collections, ciphers, collectionRelationships, importingUserId); + + await sutProvider.GetDependency().Received(1).CreateAsync_vNext( + ciphers, + Arg.Is>(cols => cols.Count() == collections.Count - 1 && + !cols.Any(c => c.Id == collections[0].Id) && // Check that the collection that already existed in the organization was not added + cols.All(c => collections.Any(x => c.Name == x.Name))), + Arg.Is>(c => c.Count() == ciphers.Count), + Arg.Is>(cus => + cus.Count() == collections.Count - 1 && + !cus.Any(cu => cu.CollectionId == collections[0].Id) && // Check that access was not added for the collection that already existed in the organization + cus.All(cu => cu.OrganizationUserId == importingOrganizationUser.Id && cu.Manage == true))); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + } + [Theory, BitAutoData] public async Task ImportIntoOrganizationalVaultAsync_ThrowsBadRequestException( Organization organization, diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 55db5a9143..44c86389e3 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -674,6 +674,32 @@ public class CipherServiceTests Arg.Is>(arg => !arg.Except(ciphers).Any())); } + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task ShareManyAsync_CorrectRevisionDate_WithBulkResourceCreationServiceEnabled_Passes(string revisionDateString, + SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(organization.Id) + .Returns(new Organization + { + PlanType = PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + + var cipherInfos = ciphers.Select(c => (c, + string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync_vNext(sharingUserId, + Arg.Is>(arg => !arg.Except(ciphers).Any())); + } + [Theory] [BitAutoData] public async Task RestoreAsync_UpdatesUserCipher(Guid restoringUserId, CipherDetails cipher, SutProvider sutProvider) @@ -1094,6 +1120,33 @@ public class CipherServiceTests Arg.Is>(arg => !arg.Except(ciphers).Any())); } + [Theory, BitAutoData] + public async Task ShareManyAsync_PaidOrgWithAttachment_WithBulkResourceCreationServiceEnabled_Passes(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization + { + PlanType = PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync_vNext(sharingUserId, + Arg.Is>(arg => !arg.Except(ciphers).Any())); + } + private class SaveDetailsAsyncDependencies { public CipherDetails CipherDetails { get; set; } diff --git a/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs b/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs index 0a186e43be..2a31398a02 100644 --- a/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/Vault/Repositories/CipherRepositoryTests.cs @@ -8,11 +8,13 @@ using Bit.Core.Models.Data; using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Repositories; +using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Enums; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; using Xunit; +using CipherType = Bit.Core.Vault.Enums.CipherType; namespace Bit.Infrastructure.IntegrationTest.Repositories; @@ -975,6 +977,161 @@ public class CipherRepositoryTests Assert.Equal("new_attachments", updatedCipher2.Attachments); } + [DatabaseTheory, DatabaseData] + public async Task CreateAsync_vNext_WithFolders_Works( + IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository) + { + // Arrange + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var folder1 = new Folder { Id = CoreHelpers.GenerateComb(), UserId = user.Id, Name = "Test Folder 1" }; + var folder2 = new Folder { Id = CoreHelpers.GenerateComb(), UserId = user.Id, Name = "Test Folder 2" }; + var cipher1 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + var cipher2 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.SecureNote, UserId = user.Id, Data = "" }; + + // Act + await cipherRepository.CreateAsync_vNext( + userId: user.Id, + ciphers: [cipher1, cipher2], + folders: [folder1, folder2]); + + // Assert + var readCipher1 = await cipherRepository.GetByIdAsync(cipher1.Id); + var readCipher2 = await cipherRepository.GetByIdAsync(cipher2.Id); + Assert.NotNull(readCipher1); + Assert.NotNull(readCipher2); + + var readFolder1 = await folderRepository.GetByIdAsync(folder1.Id); + var readFolder2 = await folderRepository.GetByIdAsync(folder2.Id); + Assert.NotNull(readFolder1); + Assert.NotNull(readFolder2); + } + + [DatabaseTheory, DatabaseData] + public async Task CreateAsync_vNext_WithCollectionsAndUsers_Works( + IOrganizationRepository orgRepository, + IOrganizationUserRepository orgUserRepository, + ICollectionRepository collectionRepository, + ICollectionCipherRepository collectionCipherRepository, + ICipherRepository cipherRepository, + IUserRepository userRepository) + { + // Arrange + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var org = await orgRepository.CreateAsync(new Organization + { + Name = "Test Organization", + BillingEmail = user.Email, + Plan = "Test" + }); + + var orgUser = await orgUserRepository.CreateAsync(new OrganizationUser + { + UserId = user.Id, + OrganizationId = org.Id, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Owner, + }); + + var collection = new Collection { Id = CoreHelpers.GenerateComb(), Name = "Test Collection", OrganizationId = org.Id }; + var cipher = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, OrganizationId = org.Id, Data = "" }; + var collectionCipher = new CollectionCipher { CollectionId = collection.Id, CipherId = cipher.Id }; + var collectionUser = new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = orgUser.Id, + HidePasswords = false, + ReadOnly = false, + Manage = true + }; + + // Act + await cipherRepository.CreateAsync_vNext( + ciphers: [cipher], + collections: [collection], + collectionCiphers: [collectionCipher], + collectionUsers: [collectionUser]); + + // Assert + var orgCiphers = await cipherRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(orgCiphers, c => c.Id == cipher.Id); + + var collCiphers = await collectionCipherRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(collCiphers, cc => cc.CipherId == cipher.Id && cc.CollectionId == collection.Id); + + var collectionsInOrg = await collectionRepository.GetManyByOrganizationIdAsync(org.Id); + Assert.Contains(collectionsInOrg, c => c.Id == collection.Id); + + var collectionUsers = await collectionRepository.GetManyUsersByIdAsync(collection.Id); + var foundCollectionUser = collectionUsers.FirstOrDefault(cu => cu.Id == orgUser.Id); + Assert.NotNull(foundCollectionUser); + Assert.True(foundCollectionUser.Manage); + Assert.False(foundCollectionUser.ReadOnly); + Assert.False(foundCollectionUser.HidePasswords); + } + + [DatabaseTheory, DatabaseData] + public async Task UpdateCiphersAsync_vNext_Works( + IUserRepository userRepository, ICipherRepository cipherRepository) + { + // Arrange + var expectedNewType = CipherType.SecureNote; + var expectedNewAttachments = "bulk_new_attachments"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"{Guid.NewGuid()}@example.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var c1 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + var c2 = new Cipher { Id = CoreHelpers.GenerateComb(), Type = CipherType.Login, UserId = user.Id, Data = "" }; + await cipherRepository.CreateAsync( + userId: user.Id, + ciphers: [c1, c2], + folders: []); + + c1.Type = expectedNewType; + c2.Attachments = expectedNewAttachments; + + // Act + await cipherRepository.UpdateCiphersAsync_vNext(user.Id, [c1, c2]); + + // Assert + var updated1 = await cipherRepository.GetByIdAsync(c1.Id); + Assert.NotNull(updated1); + Assert.Equal(c1.Id, updated1.Id); + Assert.Equal(expectedNewType, updated1.Type); + Assert.Equal(c1.UserId, updated1.UserId); + Assert.Equal(c1.Data, updated1.Data); + Assert.Equal(c1.OrganizationId, updated1.OrganizationId); + Assert.Equal(c1.Attachments, updated1.Attachments); + + var updated2 = await cipherRepository.GetByIdAsync(c2.Id); + Assert.NotNull(updated2); + Assert.Equal(c2.Id, updated2.Id); + Assert.Equal(c2.Type, updated2.Type); + Assert.Equal(c2.UserId, updated2.UserId); + Assert.Equal(c2.Data, updated2.Data); + Assert.Equal(c2.OrganizationId, updated2.OrganizationId); + Assert.Equal(expectedNewAttachments, updated2.Attachments); + } + [DatabaseTheory, DatabaseData] public async Task DeleteCipherWithSecurityTaskAsync_Works( IOrganizationRepository organizationRepository,