From e042572cfb20bc5a89b0189f59696f1c0592f9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:36:40 +0100 Subject: [PATCH] [PM-24582] Bugfix: exclude admins and owners from default user collection creation on confirmation (#6177) * Update the OrganizationUserController integration Confirm tests to handle the Owner type * Refactor ConfirmOrganizationUserCommand to simplify side-effect handling in organization user confirmation. Update IPolicyRequirementQuery to return eligible org user IDs for policy enforcement. Update tests for method signature changes and default collection creation logic. --- .../ConfirmOrganizationUserCommand.cs | 46 +++++++------ .../Policies/IPolicyRequirementQuery.cs | 13 ++-- .../Implementations/PolicyRequirementQuery.cs | 14 ++-- .../OrganizationUserControllerTests.cs | 64 ++++++++++++------- .../ConfirmOrganizationUserCommandTests.cs | 23 ++----- .../Policies/PolicyRequirementQueryTests.cs | 34 +++++----- 6 files changed, 99 insertions(+), 95 deletions(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 0baa9c9e3a..cbedb6355d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -244,25 +244,6 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand .Select(d => d.Id.ToString()); } - private async Task OrganizationRequiresDefaultCollectionAsync(Guid organizationId, string defaultUserCollectionName) - { - if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)) - { - return false; - } - - // Skip if no collection name provided (backwards compatibility) - if (string.IsNullOrWhiteSpace(defaultUserCollectionName)) - { - return false; - } - - var organizationPolicyRequirement = await _policyRequirementQuery.GetByOrganizationAsync(organizationId); - - // Check if the organization requires default collections - return organizationPolicyRequirement.RequiresDefaultCollection(organizationId); - } - /// /// Handles the side effects of confirming an organization user. /// Creates a default collection for the user if the organization @@ -271,15 +252,32 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand /// The organization ID. /// The confirmed organization users. /// The encrypted default user collection name. - private async Task HandleConfirmationSideEffectsAsync(Guid organizationId, IEnumerable confirmedOrganizationUsers, string defaultUserCollectionName) + private async Task HandleConfirmationSideEffectsAsync(Guid organizationId, + IEnumerable confirmedOrganizationUsers, string defaultUserCollectionName) { - var requiresDefaultCollections = await OrganizationRequiresDefaultCollectionAsync(organizationId, defaultUserCollectionName); - if (!requiresDefaultCollections) + if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)) { return; } - var organizationUserIds = confirmedOrganizationUsers.Select(u => u.Id).ToList(); - await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultUserCollectionName); + // Skip if no collection name provided (backwards compatibility) + if (string.IsNullOrWhiteSpace(defaultUserCollectionName)) + { + return; + } + + var policyEligibleOrganizationUserIds = await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId); + + var eligibleOrganizationUserIds = confirmedOrganizationUsers + .Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id)) + .Select(ou => ou.Id) + .ToList(); + + if (eligibleOrganizationUserIds.Count == 0) + { + return; + } + + await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs index 226347fe29..e662716142 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs @@ -17,12 +17,11 @@ public interface IPolicyRequirementQuery Task GetAsync(Guid userId) where T : IPolicyRequirement; /// - /// Get a policy requirement for a specific organization. - /// This returns the policy requirement that represents the policy state for the entire organization. - /// It will always return a value even if there are no policies that should be enforced. - /// This should be used for organization-level policy checks. + /// Get all organization user IDs within an organization that are affected by a given policy type. + /// Respects role/status/provider exemptions via the policy factory's Enforce predicate. /// - /// The organization to check policies for. - /// The IPolicyRequirement that corresponds to the policy you want to enforce. - Task GetByOrganizationAsync(Guid organizationId) where T : IPolicyRequirement; + /// The organization to check. + /// The IPolicyRequirement that corresponds to the policy type to evaluate. + /// Organization user IDs for whom the policy applies within the organization. + Task> GetManyByOrganizationIdAsync(Guid organizationId) where T : IPolicyRequirement; } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index ba4495224c..e846e02e46 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -28,7 +28,8 @@ public class PolicyRequirementQuery( return requirement; } - public async Task GetByOrganizationAsync(Guid organizationId) where T : IPolicyRequirement + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + where T : IPolicyRequirement { var factory = factories.OfType>().SingleOrDefault(); if (factory is null) @@ -37,13 +38,14 @@ public class PolicyRequirementQuery( } var organizationPolicyDetails = await GetOrganizationPolicyDetails(organizationId, factory.PolicyType); - var filteredPolicies = organizationPolicyDetails - .Cast() - .Where(policyDetails => policyDetails.PolicyType == factory.PolicyType) + + var eligibleOrganizationUserIds = organizationPolicyDetails + .Where(p => p.PolicyType == factory.PolicyType) .Where(factory.Enforce) + .Select(p => p.OrganizationUserId) .ToList(); - var requirement = factory.Create(filteredPolicies); - return requirement; + + return eligibleOrganizationUserIds; } private Task> GetPolicyDetails(Guid userId) diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs index 08ebcf5de0..04ab72fad1 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs @@ -22,7 +22,6 @@ public class OrganizationUserControllerTests : IClassFixture (organizationUser, string.Format(testKeyFormat, index))).ToList()); - await VerifyMultipleUsersHaveDefaultCollectionsAsync(acceptedUsers); + await VerifyDefaultCollectionCountAsync(acceptedUsers.ElementAt(0), 1); + await VerifyDefaultCollectionCountAsync(acceptedUsers.ElementAt(1), 0); // Owner does not get a default collection + await VerifyDefaultCollectionCountAsync(acceptedUsers.ElementAt(2), 1); } [Fact] @@ -294,16 +320,18 @@ public class OrganizationUserControllerTests : IClassFixture> CreateAcceptedUsersAsync(IEnumerable emails) + private async Task> CreateAcceptedUsersAsync( + IEnumerable<(string email, OrganizationUserType userType)> newUsers) { var acceptedUsers = new List(); - foreach (var email in emails) + foreach (var (email, userType) in newUsers) { await _factory.LoginWithNewAccount(email); - var acceptedOrgUser = await OrganizationTestHelpers.CreateUserAsync(_factory, _organization.Id, email, - OrganizationUserType.User, userStatusType: OrganizationUserStatusType.Accepted); + var acceptedOrgUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, _organization.Id, email, + userType, userStatusType: OrganizationUserStatusType.Accepted); acceptedUsers.Add(acceptedOrgUser); } @@ -311,12 +339,11 @@ public class OrganizationUserControllerTests : IClassFixture(); var collections = await collectionRepository.GetManyByUserIdAsync(orgUser.UserId!.Value); - Assert.Single(collections); - Assert.Equal(_mockEncryptedString, collections.First().Name); + Assert.Equal(expectedCount, collections.Count); } private async Task VerifyUserConfirmedAsync(OrganizationUser orgUser, string expectedKey) @@ -334,15 +361,4 @@ public class OrganizationUserControllerTests : IClassFixture acceptedOrganizationUsers) - { - var collectionRepository = _factory.GetService(); - foreach (var acceptedOrganizationUser in acceptedOrganizationUsers) - { - var collections = await collectionRepository.GetManyByUserIdAsync(acceptedOrganizationUser.UserId!.Value); - Assert.Single(collections); - Assert.Equal(_mockEncryptedString, collections.First().Name); - } - } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index b0815d9f35..a8219ebcaa 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -10,7 +10,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Platform.Push; using Bit.Core.Repositories; @@ -473,10 +472,8 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); sutProvider.GetDependency() - .GetByOrganizationAsync(organization.Id) - .Returns(new OrganizationDataOwnershipPolicyRequirement( - OrganizationDataOwnershipState.Enabled, - [organization.Id])); + .GetManyByOrganizationIdAsync(organization.Id) + .Returns(new List { orgUser.Id }); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); @@ -504,17 +501,11 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); - sutProvider.GetDependency() - .GetByOrganizationAsync(org.Id) - .Returns(new OrganizationDataOwnershipPolicyRequirement( - OrganizationDataOwnershipState.Enabled, - [org.Id])); - await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, ""); await sutProvider.GetDependency() .DidNotReceive() - .CreateAsync(Arg.Any(), Arg.Any>(), Arg.Any>()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -533,15 +524,13 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); sutProvider.GetDependency() - .GetByOrganizationAsync(org.Id) - .Returns(new OrganizationDataOwnershipPolicyRequirement( - OrganizationDataOwnershipState.Enabled, - [Guid.NewGuid()])); + .GetManyByOrganizationIdAsync(org.Id) + .Returns(new List { orgUser.UserId!.Value }); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); await sutProvider.GetDependency() .DidNotReceive() - .CreateAsync(Arg.Any(), Arg.Any>(), Arg.Any>()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index da8f7319d5..8c25f70454 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -81,11 +81,11 @@ public class PolicyRequirementQueryTests } [Theory, BitAutoData] - public async Task GetByOrganizationAsync_IgnoresOtherPolicyTypes(Guid organizationId) + public async Task GetManyByOrganizationIdAsync_IgnoresOtherPolicyTypes(Guid organizationId) { var policyRepository = Substitute.For(); - var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; - var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = Guid.NewGuid() }; + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, OrganizationUserId = Guid.NewGuid() }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, OrganizationUserId = Guid.NewGuid() }; // Force the repository to return both policies even though that is not the expected result policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg) .Returns([thisPolicy, otherPolicy]); @@ -93,20 +93,20 @@ public class PolicyRequirementQueryTests var factory = new TestPolicyRequirementFactory(_ => true); var sut = new PolicyRequirementQuery(policyRepository, [factory]); - var requirement = await sut.GetByOrganizationAsync(organizationId); + var organizationUserIds = await sut.GetManyByOrganizationIdAsync(organizationId); await policyRepository.Received(1).GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg); - Assert.Contains(thisPolicy, requirement.Policies.Cast()); - Assert.DoesNotContain(otherPolicy, requirement.Policies.Cast()); + Assert.Contains(thisPolicy.OrganizationUserId, organizationUserIds); + Assert.DoesNotContain(otherPolicy.OrganizationUserId, organizationUserIds); } [Theory, BitAutoData] - public async Task GetByOrganizationAsync_CallsEnforceCallback(Guid organizationId) + public async Task GetManyByOrganizationIdAsync_CallsEnforceCallback(Guid organizationId) { var policyRepository = Substitute.For(); - var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; - var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = Guid.NewGuid() }; + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, OrganizationUserId = Guid.NewGuid() }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, OrganizationUserId = Guid.NewGuid() }; policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg).Returns([thisPolicy, otherPolicy]); var callback = Substitute.For>(); @@ -115,28 +115,28 @@ public class PolicyRequirementQueryTests var factory = new TestPolicyRequirementFactory(callback); var sut = new PolicyRequirementQuery(policyRepository, [factory]); - var requirement = await sut.GetByOrganizationAsync(organizationId); + var organizationUserIds = await sut.GetManyByOrganizationIdAsync(organizationId); - Assert.Contains(thisPolicy, requirement.Policies.Cast()); - Assert.DoesNotContain(otherPolicy, requirement.Policies.Cast()); + Assert.Contains(thisPolicy.OrganizationUserId, organizationUserIds); + Assert.DoesNotContain(otherPolicy.OrganizationUserId, organizationUserIds); callback.Received()(Arg.Is(p => p == thisPolicy)); callback.Received()(Arg.Is(p => p == otherPolicy)); } [Theory, BitAutoData] - public async Task GetByOrganizationAsync_ThrowsIfNoFactoryRegistered(Guid organizationId) + public async Task GetManyByOrganizationIdAsync_ThrowsIfNoFactoryRegistered(Guid organizationId) { var policyRepository = Substitute.For(); var sut = new PolicyRequirementQuery(policyRepository, []); var exception = await Assert.ThrowsAsync(() - => sut.GetByOrganizationAsync(organizationId)); + => sut.GetManyByOrganizationIdAsync(organizationId)); Assert.Contains("No Requirement Factory found", exception.Message); } [Theory, BitAutoData] - public async Task GetByOrganizationAsync_HandlesNoPolicies(Guid organizationId) + public async Task GetManyByOrganizationIdAsync_HandlesNoPolicies(Guid organizationId) { var policyRepository = Substitute.For(); policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, PolicyType.SingleOrg).Returns([]); @@ -144,8 +144,8 @@ public class PolicyRequirementQueryTests var factory = new TestPolicyRequirementFactory(x => x.IsProvider); var sut = new PolicyRequirementQuery(policyRepository, [factory]); - var requirement = await sut.GetByOrganizationAsync(organizationId); + var organizationUserIds = await sut.GetManyByOrganizationIdAsync(organizationId); - Assert.Empty(requirement.Policies); + Assert.Empty(organizationUserIds); } }