diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index 18bf62da65..22d9fd20dd 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -64,16 +64,26 @@ public class OrganizationDataOwnershipPolicyValidator( var userOrgIds = requirements .Select(requirement => requirement.GetDefaultCollectionRequestOnPolicyEnable(policyUpdate.OrganizationId)) .Where(request => request.ShouldCreateDefaultCollection) - .Select(request => request.OrganizationUserId); + .Select(request => request.OrganizationUserId) + .ToList(); if (!userOrgIds.Any()) { return; } - await collectionRepository.UpsertDefaultCollectionsBulkAsync( + // Filter out users who already have default collections + var existingSemaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(userOrgIds); + var usersNeedingDefaultCollections = userOrgIds.Except(existingSemaphores).ToList(); + + if (!usersNeedingDefaultCollections.Any()) + { + return; + } + + await collectionRepository.CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, - userOrgIds, + usersNeedingDefaultCollections, defaultCollectionName); } } diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index f51f64a7d8..91232db058 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -74,12 +74,17 @@ public interface ICollectionRepository : IRepository /// /// Creates default user collections for the specified organization users using bulk insert operations. - /// Gracefully skips users who already have a default collection for the organization. + /// Use this if you need to create collections for > ~1k users. + /// Throws an exception if any user already has a default collection for the organization. /// /// The Organization ID. /// The Organization User IDs to create default collections for. /// The encrypted string to use as the default collection name. - Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + /// + /// If any of the OrganizationUsers may already have default collections, the caller should first filter out these + /// users using GetDefaultCollectionSemaphoresAsync before calling this method. + /// + Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); /// /// Gets default collection semaphores for the given organizationUserIds. diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index df82a0c158..ac289eafb9 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -383,7 +383,7 @@ public class CollectionRepository : Repository, ICollectionRep commandType: CommandType.StoredProcedure); } - public async Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -391,8 +391,7 @@ public class CollectionRepository : Repository, ICollectionRep return; } - var orgUserIdWithDefaultCollection = await GetDefaultCollectionSemaphoresAsync(organizationUserIds); - var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); + var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, organizationUserIds, defaultCollectionName); await using var connection = new SqlConnection(ConnectionString); connection.Open(); @@ -400,15 +399,9 @@ public class CollectionRepository : Repository, ICollectionRep try { - var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); - - if (!collectionUsers.Any() || !collections.Any()) - { - return; - } // CRITICAL: Insert semaphore entries BEFORE collections - // TODO: this will result in a creation date of the semaphore AFTER that of the collection, which is weird + // Database will throw on duplicate primary key (OrganizationUserId) var now = DateTime.UtcNow; var semaphores = collectionUsers.Select(c => new DefaultCollectionSemaphore { diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index d152a67afb..1fb387f48a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -803,21 +803,13 @@ public class CollectionRepository : Repository new DefaultCollectionSemaphore { @@ -832,7 +824,7 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { // EF uses the same bulk copy approach as the main method await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs index 0c59bbe0cf..b61888759a 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs @@ -40,7 +40,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -66,7 +66,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -92,7 +92,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -118,7 +118,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceive() - .UpsertDefaultCollectionsBulkAsync( + .CreateDefaultCollectionsBulkAsync( Arg.Any(), Arg.Any>(), Arg.Any()); @@ -198,6 +198,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList); var collectionRepository = Substitute.For(); + // Mock GetDefaultCollectionSemaphoresAsync to return empty set (no existing collections) + collectionRepository + .GetDefaultCollectionSemaphoresAsync(Arg.Any>()) + .Returns(new HashSet()); + var sut = ArrangeSut(factory, policyRepository, collectionRepository); var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); @@ -207,12 +212,101 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .Received(1) - .UpsertDefaultCollectionsBulkAsync( + .GetDefaultCollectionSemaphoresAsync(Arg.Is>(ids => ids.Count() == 3)); + + await collectionRepository + .Received(1) + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); } + [Theory] + [BitMemberAutoData(nameof(ShouldUpsertDefaultCollectionsTestCases))] + public async Task ExecuteSideEffectsAsync_FiltersOutUsersWithExistingCollections( + Policy postUpdatedPolicy, + Policy? previousPolicyState, + [PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate, + [OrganizationPolicyDetails(PolicyType.OrganizationDataOwnership)] IEnumerable orgPolicyDetails, + OrganizationDataOwnershipPolicyRequirementFactory factory) + { + // Arrange + var orgPolicyDetailsList = orgPolicyDetails.ToList(); + foreach (var policyDetail in orgPolicyDetailsList) + { + policyDetail.OrganizationId = policyUpdate.OrganizationId; + } + + var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList); + var collectionRepository = Substitute.For(); + + // Mock GetDefaultCollectionSemaphoresAsync to return one existing user + var existingUserId = orgPolicyDetailsList[0].OrganizationUserId; + collectionRepository + .GetDefaultCollectionSemaphoresAsync(Arg.Any>()) + .Returns([existingUserId]); + + var sut = ArrangeSut(factory, policyRepository, collectionRepository); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert - Should filter out the existing user + await collectionRepository + .Received(1) + .GetDefaultCollectionSemaphoresAsync(Arg.Is>(ids => ids.Count() == 3)); + + await collectionRepository + .Received(1) + .CreateDefaultCollectionsBulkAsync( + policyUpdate.OrganizationId, + Arg.Is>(ids => ids.Count() == 2 && !ids.Contains(existingUserId)), + _defaultUserCollectionName); + } + + [Theory] + [BitMemberAutoData(nameof(ShouldUpsertDefaultCollectionsTestCases))] + public async Task ExecuteSideEffectsAsync_DoesNotCallRepository_WhenAllUsersHaveExistingCollections( + Policy postUpdatedPolicy, + Policy? previousPolicyState, + [PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate, + [OrganizationPolicyDetails(PolicyType.OrganizationDataOwnership)] IEnumerable orgPolicyDetails, + OrganizationDataOwnershipPolicyRequirementFactory factory) + { + // Arrange + var orgPolicyDetailsList = orgPolicyDetails.ToList(); + foreach (var policyDetail in orgPolicyDetailsList) + { + policyDetail.OrganizationId = policyUpdate.OrganizationId; + } + + var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList); + var collectionRepository = Substitute.For(); + + // Mock GetDefaultCollectionSemaphoresAsync to return all users + var allUserIds = orgPolicyDetailsList.Select(p => p.OrganizationUserId).ToHashSet(); + collectionRepository + .GetDefaultCollectionSemaphoresAsync(Arg.Any>()) + .Returns(allUserIds); + + var sut = ArrangeSut(factory, policyRepository, collectionRepository); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert - Should not call CreateDefaultCollectionsBulkAsync when all users already have collections + await collectionRepository + .Received(1) + .GetDefaultCollectionSemaphoresAsync(Arg.Is>(ids => ids.Count() == 3)); + + await collectionRepository + .DidNotReceive() + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + private static IEnumerable WhenDefaultCollectionsDoesNotExistTestCases() { yield return [new OrganizationModelOwnershipPolicyModel(null)]; @@ -246,7 +340,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } private static IPolicyRepository ArrangePolicyRepository(IEnumerable policyDetails) @@ -294,7 +388,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsBulkAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -320,7 +414,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsBulkAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -346,7 +440,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsBulkAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -372,7 +466,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsBulkAsync( + .CreateDefaultCollectionsBulkAsync( default, default, default); @@ -403,6 +497,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList); var collectionRepository = Substitute.For(); + // Mock GetDefaultCollectionSemaphoresAsync to return empty set (no existing collections) + collectionRepository + .GetDefaultCollectionSemaphoresAsync(Arg.Any>()) + .Returns(new HashSet()); + var sut = ArrangeSut(factory, policyRepository, collectionRepository); var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); @@ -412,7 +511,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .Received(1) - .UpsertDefaultCollectionsBulkAsync( + .GetDefaultCollectionSemaphoresAsync(Arg.Is>(ids => ids.Count() == 3)); + + await collectionRepository + .Received(1) + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -444,6 +547,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsBulkAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } } diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsBulkTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs similarity index 66% rename from test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsBulkTests.cs rename to test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs index 1376465a5b..5851d8d468 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsBulkTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs @@ -6,10 +6,10 @@ using Xunit; namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; -public class UpsertDefaultCollectionsBulkTests +public class CreateDefaultCollectionsBulkTests { [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsBulkAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection( + public async Task CreateDefaultCollectionsBulkAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection( IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -27,7 +27,7 @@ public class UpsertDefaultCollectionsBulkTests var defaultCollectionName = $"default-name-{organization.Id}"; // Act - await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); // Assert await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); @@ -37,7 +37,7 @@ public class UpsertDefaultCollectionsBulkTests } [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsBulkAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist( + public async Task CreateDefaultCollectionsBulkAsync_CreatesForNewUsersOnly_WhenCallerFiltersExisting( IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -64,18 +64,20 @@ public class UpsertDefaultCollectionsBulkTests var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers); var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id).ToList(); - // Act - await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + // Act - Caller filters out existing users (new pattern) + var existingSemaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(affectedOrgUserIds); + var usersNeedingCollections = affectedOrgUserIds.Except(existingSemaphores).ToList(); + await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, usersNeedingCollections, defaultCollectionName); - // Assert - await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id); + // Assert - All users now have exactly one collection + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, affectedOrgUsers, organization.Id); await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds); await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers); } [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsBulkAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne( + public async Task CreateDefaultCollectionsBulkAsync_ThrowsException_WhenUsersAlreadyHaveOne( IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -94,21 +96,61 @@ public class UpsertDefaultCollectionsBulkTests await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); - // Act - await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + // Act - Try to create again, should throw database constraint exception + await Assert.ThrowsAnyAsync(() => + collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName)); - // Assert + // Assert - Original collections should remain unchanged await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds); await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); } + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_ThrowsException_WhenDuplicatesNotFiltered( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + // Arrange + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var existingUser = await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization); + var newUser = await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization); + var defaultCollectionName = $"default-name-{organization.Id}"; + + // Create collection for existing user + await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, [existingUser.Id], defaultCollectionName); + + // Act - Try to create for both without filtering (incorrect usage) + await Assert.ThrowsAnyAsync(() => + collectionRepository.CreateDefaultCollectionsBulkAsync( + organization.Id, + [existingUser.Id, newUser.Id], + defaultCollectionName)); + + // Assert - Verify existing user still has collection + var existingUserCollections = await collectionRepository.GetManyByUserIdAsync(existingUser.UserId!.Value); + var existingUserDefaultCollection = existingUserCollections + .SingleOrDefault(c => c.OrganizationId == organization.Id && c.Type == CollectionType.DefaultUserCollection); + Assert.NotNull(existingUserDefaultCollection); + + // Verify new user does NOT have collection (transaction rolled back) + var newUserCollections = await collectionRepository.GetManyByUserIdAsync(newUser.UserId!.Value); + var newUserDefaultCollection = newUserCollections + .FirstOrDefault(c => c.OrganizationId == organization.Id && c.Type == CollectionType.DefaultUserCollection); + Assert.Null(newUserDefaultCollection); + + await CleanupAsync(organizationRepository, userRepository, organization, [existingUser, newUser]); + } + private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository, Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName, OrganizationUser[] resultOrganizationUsers) { - await collectionRepository.UpsertDefaultCollectionsBulkAsync(organizationId, affectedOrgUserIds, defaultCollectionName); + await collectionRepository.CreateDefaultCollectionsBulkAsync(organizationId, affectedOrgUserIds, defaultCollectionName); await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId); }