diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
index e662716142..2d6bd94fd1 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyRequirementQuery.cs
@@ -16,6 +16,17 @@ public interface IPolicyRequirementQuery
/// The IPolicyRequirement that corresponds to the policy you want to enforce.
Task GetAsync(Guid userId) where T : IPolicyRequirement;
+ ///
+ /// Get a policy requirement for a list of users.
+ /// The policy requirement represents how one or more policy types should be enforced against the users.
+ ///
+ ///
+ /// A collection of tuples pairing each user ID with their corresponding policy requirement.
+ ///
+ /// The users that you need to enforce the policy against.
+ /// The IPolicyRequirement that corresponds to the policy you want to enforce.
+ Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement;
+
///
/// Get all organization user IDs within an organization that are affected by a given policy type.
/// Respects role/status/provider exemptions via the policy factory's Enforce predicate.
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
index c1450c6ab5..8090691540 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs
@@ -11,6 +11,9 @@ public class PolicyRequirementQuery(
: IPolicyRequirementQuery
{
public async Task GetAsync(Guid userId) where T : IPolicyRequirement
+ => (await GetAsync([userId])).Single().Requirement;
+
+ public async Task> GetAsync(IEnumerable userIds) where T : IPolicyRequirement
{
var factory = factories.OfType>().SingleOrDefault();
if (factory is null)
@@ -18,12 +21,15 @@ public class PolicyRequirementQuery(
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}
- var policyDetails = await GetPolicyDetails(userId, factory.PolicyType);
- var filteredPolicies = policyDetails
- .Where(p => p.PolicyType == factory.PolicyType)
- .Where(factory.Enforce);
- var requirement = factory.Create(filteredPolicies);
- return requirement;
+ var userIdList = userIds.ToList();
+
+ var policyDetailsByUser = (await GetPolicyDetails(userIdList, factory.PolicyType))
+ .Where(factory.Enforce)
+ .ToLookup(l => l.UserId);
+
+ var policyRequirements = userIdList.Select(u => (u, factory.Create(policyDetailsByUser[u])));
+
+ return policyRequirements;
}
public async Task> GetManyByOrganizationIdAsync(Guid organizationId)
@@ -46,8 +52,8 @@ public class PolicyRequirementQuery(
return eligibleOrganizationUserIds;
}
- private async Task> GetPolicyDetails(Guid userId, PolicyType policyType)
- => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType);
+ private async Task> GetPolicyDetails(IEnumerable userIds, PolicyType policyType)
+ => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(userIds, policyType);
private async Task> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);
diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
index 9115ae5ba1..823de89757 100644
--- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
+++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs
@@ -11,25 +11,6 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
[SutProviderCustomize]
public class PolicyRequirementQueryTests
{
- [Theory, BitAutoData]
- public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId)
- {
- var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
- var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId };
- var policyRepository = Substitute.For();
- policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
- Arg.Is>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
- .Returns([otherPolicy, thisPolicy]);
-
- var factory = new TestPolicyRequirementFactory(_ => true);
- var sut = new PolicyRequirementQuery(policyRepository, [factory]);
-
- var requirement = await sut.GetAsync(userId);
-
- Assert.Contains(thisPolicy, requirement.Policies);
- Assert.DoesNotContain(otherPolicy, requirement.Policies);
- }
-
[Theory, BitAutoData]
public async Task GetAsync_CallsEnforceCallback(Guid userId)
{
@@ -86,6 +67,134 @@ public class PolicyRequirementQueryTests
Assert.Empty(requirement.Policies);
}
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
+ var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB };
+ policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
+ Arg.Any>(), PolicyType.SingleOrg)
+ .Returns([policyA, policyB]);
+
+ var factory = new TestPolicyRequirementFactory(_ => true);
+ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
+
+ var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
+
+ Assert.Equal(2, requirements.Count);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.DoesNotContain(policyB, requirements[0].Requirement.Policies);
+ Assert.Contains(policyB, requirements[1].Requirement.Policies);
+ Assert.DoesNotContain(policyA, requirements[1].Requirement.Policies);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_CallsEnforceCallback(Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
+ var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB };
+ policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
+ Arg.Any>(), PolicyType.SingleOrg)
+ .Returns([policyA, policyB]);
+
+ var callback = Substitute.For>();
+ callback(Arg.Any()).Returns(x => x.Arg() == policyA);
+
+ var factory = new TestPolicyRequirementFactory(callback);
+ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
+
+ var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
+
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
+ callback.Received()(Arg.Is(policyA));
+ callback.Received()(Arg.Is(policyB));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_FiltersOutPoliciesThatAreNotEnforced(Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ var enforcedPolicyA = new OrganizationPolicyDetails
+ { PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = false };
+ var notEnforcedPolicyA = new OrganizationPolicyDetails
+ { PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = true };
+ var enforcedPolicyB = new OrganizationPolicyDetails
+ { PolicyType = PolicyType.SingleOrg, UserId = userIdB, IsProvider = false };
+ policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
+ Arg.Any>(), PolicyType.SingleOrg)
+ .Returns([enforcedPolicyA, notEnforcedPolicyA, enforcedPolicyB]);
+
+ // Enforce returns false for providers (filtering them out)
+ var factory = new TestPolicyRequirementFactory(p => !p.IsProvider);
+ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
+
+ var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
+
+ Assert.Equal(2, requirements.Count);
+ Assert.Contains(enforcedPolicyA, requirements[0].Requirement.Policies);
+ Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Requirement.Policies);
+ Assert.Contains(enforcedPolicyB, requirements[1].Requirement.Policies);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_ThrowsIfNoFactoryRegistered(Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ var sut = new PolicyRequirementQuery(policyRepository, []);
+
+ var exception = await Assert.ThrowsAsync(()
+ => sut.GetAsync([userIdA, userIdB]));
+ Assert.Contains("No Requirement Factory found", exception.Message);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_HandlesNoPolicies(Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
+ Arg.Any>(), PolicyType.SingleOrg)
+ .Returns([]);
+
+ var factory = new TestPolicyRequirementFactory(_ => true);
+ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
+
+ var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
+
+ Assert.Equal(2, requirements.Count);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Empty(requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetAsync_WithMultipleUserIds_ReturnsEmptyRequirementForUserWithoutPolicies(
+ Guid userIdA, Guid userIdB)
+ {
+ var policyRepository = Substitute.For();
+ var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
+ // Only userIdA has a policy, userIdB has none
+ policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
+ Arg.Any>(), PolicyType.SingleOrg)
+ .Returns([policyA]);
+
+ var factory = new TestPolicyRequirementFactory(_ => true);
+ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
+
+ var requirements = (await sut.GetAsync([userIdA, userIdB])).ToList();
+
+ Assert.Equal(2, requirements.Count);
+ Assert.Equal(userIdA, requirements[0].UserId);
+ Assert.Equal(userIdB, requirements[1].UserId);
+ Assert.Contains(policyA, requirements[0].Requirement.Policies);
+ Assert.Empty(requirements[1].Requirement.Policies);
+ }
+
[Theory, BitAutoData]
public async Task GetManyByOrganizationIdAsync_IgnoresOtherPolicyTypes(Guid organizationId)
{