diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs index e662716142..2d6bd94fd1 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs @@ -16,6 +16,17 @@ public interface IPolicyRequirementQuery /// The IPolicyRequirement that corresponds to the policy you want to enforce. Task GetAsync(Guid userId) where T : IPolicyRequirement; + /// + /// Get a policy requirement for a list of users. + /// The policy requirement represents how one or more policy types should be enforced against the users. + /// + /// + /// A collection of tuples pairing each user ID with their corresponding policy requirement. + /// + /// The users that you need to enforce the policy against. + /// The IPolicyRequirement that corresponds to the policy you want to enforce. + Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement; + /// /// Get all organization user IDs within an organization that are affected by a given policy type. /// Respects role/status/provider exemptions via the policy factory's Enforce predicate. diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index c1450c6ab5..8090691540 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -11,6 +11,9 @@ public class PolicyRequirementQuery( : IPolicyRequirementQuery { public async Task GetAsync(Guid userId) where T : IPolicyRequirement + => (await GetAsync([userId])).Single().Requirement; + + public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement { var factory = factories.OfType>().SingleOrDefault(); if (factory is null) @@ -18,12 +21,15 @@ public class PolicyRequirementQuery( throw new NotImplementedException("No Requirement Factory found for " + typeof(T)); } - var policyDetails = await GetPolicyDetails(userId, factory.PolicyType); - var filteredPolicies = policyDetails - .Where(p => p.PolicyType == factory.PolicyType) - .Where(factory.Enforce); - var requirement = factory.Create(filteredPolicies); - return requirement; + var userIdList = userIds.ToList(); + + var policyDetailsByUser = (await GetPolicyDetails(userIdList, factory.PolicyType)) + .Where(factory.Enforce) + .ToLookup(l => l.UserId); + + var policyRequirements = userIdList.Select(u => (u, factory.Create(policyDetailsByUser[u]))); + + return policyRequirements; } public async Task> GetManyByOrganizationIdAsync(Guid organizationId) @@ -46,8 +52,8 @@ public class PolicyRequirementQuery( return eligibleOrganizationUserIds; } - private async Task> GetPolicyDetails(Guid userId, PolicyType policyType) - => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType); + private async Task> GetPolicyDetails(IEnumerable userIds, PolicyType policyType) + => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(userIds, policyType); private async Task> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType) => await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index 9115ae5ba1..823de89757 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -11,25 +11,6 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; [SutProviderCustomize] public class PolicyRequirementQueryTests { - [Theory, BitAutoData] - public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId) - { - var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId }; - var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId }; - var policyRepository = Substitute.For(); - policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( - Arg.Is>(ids => ids.Contains(userId)), PolicyType.SingleOrg) - .Returns([otherPolicy, thisPolicy]); - - var factory = new TestPolicyRequirementFactory(_ => true); - var sut = new PolicyRequirementQuery(policyRepository, [factory]); - - var requirement = await sut.GetAsync(userId); - - Assert.Contains(thisPolicy, requirement.Policies); - Assert.DoesNotContain(otherPolicy, requirement.Policies); - } - [Theory, BitAutoData] public async Task GetAsync_CallsEnforceCallback(Guid userId) { @@ -86,6 +67,134 @@ public class PolicyRequirementQueryTests Assert.Empty(requirement.Policies); } + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA }; + var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB }; + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Any>(), PolicyType.SingleOrg) + .Returns([policyA, policyB]); + + var factory = new TestPolicyRequirementFactory(_ => true); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); + + Assert.Equal(2, requirements.Count); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.DoesNotContain(policyB, requirements[0].Requirement.Policies); + Assert.Contains(policyB, requirements[1].Requirement.Policies); + Assert.DoesNotContain(policyA, requirements[1].Requirement.Policies); + } + + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_CallsEnforceCallback(Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA }; + var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB }; + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Any>(), PolicyType.SingleOrg) + .Returns([policyA, policyB]); + + var callback = Substitute.For>(); + callback(Arg.Any()).Returns(x => x.Arg() == policyA); + + var factory = new TestPolicyRequirementFactory(callback); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); + + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); + callback.Received()(Arg.Is(policyA)); + callback.Received()(Arg.Is(policyB)); + } + + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_FiltersOutPoliciesThatAreNotEnforced(Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + var enforcedPolicyA = new OrganizationPolicyDetails + { PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = false }; + var notEnforcedPolicyA = new OrganizationPolicyDetails + { PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = true }; + var enforcedPolicyB = new OrganizationPolicyDetails + { PolicyType = PolicyType.SingleOrg, UserId = userIdB, IsProvider = false }; + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Any>(), PolicyType.SingleOrg) + .Returns([enforcedPolicyA, notEnforcedPolicyA, enforcedPolicyB]); + + // Enforce returns false for providers (filtering them out) + var factory = new TestPolicyRequirementFactory(p => !p.IsProvider); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); + + Assert.Equal(2, requirements.Count); + Assert.Contains(enforcedPolicyA, requirements[0].Requirement.Policies); + Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Requirement.Policies); + Assert.Contains(enforcedPolicyB, requirements[1].Requirement.Policies); + } + + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_ThrowsIfNoFactoryRegistered(Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + var sut = new PolicyRequirementQuery(policyRepository, []); + + var exception = await Assert.ThrowsAsync(() + => sut.GetAsync([userIdA, userIdB])); + Assert.Contains("No Requirement Factory found", exception.Message); + } + + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_HandlesNoPolicies(Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Any>(), PolicyType.SingleOrg) + .Returns([]); + + var factory = new TestPolicyRequirementFactory(_ => true); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); + + Assert.Equal(2, requirements.Count); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Empty(requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); + } + + [Theory, BitAutoData] + public async Task GetAsync_WithMultipleUserIds_ReturnsEmptyRequirementForUserWithoutPolicies( + Guid userIdA, Guid userIdB) + { + var policyRepository = Substitute.For(); + var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA }; + // Only userIdA has a policy, userIdB has none + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Any>(), PolicyType.SingleOrg) + .Returns([policyA]); + + var factory = new TestPolicyRequirementFactory(_ => true); + var sut = new PolicyRequirementQuery(policyRepository, [factory]); + + var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList(); + + Assert.Equal(2, requirements.Count); + Assert.Equal(userIdA, requirements[0].UserId); + Assert.Equal(userIdB, requirements[1].UserId); + Assert.Contains(policyA, requirements[0].Requirement.Policies); + Assert.Empty(requirements[1].Requirement.Policies); + } + [Theory, BitAutoData] public async Task GetManyByOrganizationIdAsync_IgnoresOtherPolicyTypes(Guid organizationId) {