1
0
mirror of https://github.com/bitwarden/server synced 2026-02-18 10:23:27 +00:00

[PM-24759] Add Method for Retrieving Policy Requirements for Multiple Users (#6876)

* Adds new method for retrieving policy requirements for a collection of user IDs

* Use Single instead of First for explicit correctness

* Fix xmldoc

* Refactor return type to include user ID
This commit is contained in:
sven-bitwarden
2026-02-17 10:33:27 -06:00
committed by GitHub
parent 072f6c57a8
commit 0874163911
3 changed files with 153 additions and 27 deletions

View File

@@ -16,6 +16,17 @@ public interface IPolicyRequirementQuery
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement;
/// <summary>
/// Get a policy requirement for a list of users.
/// The policy requirement represents how one or more policy types should be enforced against the users.
/// </summary>
/// <returns>
/// A collection of tuples pairing each user ID with their corresponding policy requirement.
/// </returns>
/// <param name="userIds">The users that you need to enforce the policy against.</param>
/// <typeparam name="T">The IPolicyRequirement that corresponds to the policy you want to enforce.</typeparam>
Task<IEnumerable<(Guid UserId, T Requirement)>> GetAsync<T>(IEnumerable<Guid> userIds) where T : IPolicyRequirement;
/// <summary>
/// Get all organization user IDs within an organization that are affected by a given policy type.
/// Respects role/status/provider exemptions via the policy factory's Enforce predicate.

View File

@@ -11,6 +11,9 @@ public class PolicyRequirementQuery(
: IPolicyRequirementQuery
{
public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
=> (await GetAsync<T>([userId])).Single().Requirement;
public async Task<IEnumerable<(Guid UserId, T Requirement)>> GetAsync<T>(IEnumerable<Guid> userIds) where T : IPolicyRequirement
{
var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
if (factory is null)
@@ -18,12 +21,15 @@ public class PolicyRequirementQuery(
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}
var policyDetails = await GetPolicyDetails(userId, factory.PolicyType);
var filteredPolicies = policyDetails
.Where(p => p.PolicyType == factory.PolicyType)
.Where(factory.Enforce);
var requirement = factory.Create(filteredPolicies);
return requirement;
var userIdList = userIds.ToList();
var policyDetailsByUser = (await GetPolicyDetails(userIdList, factory.PolicyType))
.Where(factory.Enforce)
.ToLookup(l => l.UserId);
var policyRequirements = userIdList.Select(u => (u, factory.Create(policyDetailsByUser[u])));
return policyRequirements;
}
public async Task<IEnumerable<Guid>> GetManyByOrganizationIdAsync<T>(Guid organizationId)
@@ -46,8 +52,8 @@ public class PolicyRequirementQuery(
return eligibleOrganizationUserIds;
}
private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(Guid userId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType);
private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(IEnumerable<Guid> userIds, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(userIds, policyType);
private async Task<IEnumerable<OrganizationPolicyDetails>> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);

View File

@@ -11,25 +11,6 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
[SutProviderCustomize]
public class PolicyRequirementQueryTests
{
[Theory, BitAutoData]
public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId)
{
var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId };
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([otherPolicy, thisPolicy]);
var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirement = await sut.GetAsync<TestPolicyRequirement>(userId);
Assert.Contains(thisPolicy, requirement.Policies);
Assert.DoesNotContain(otherPolicy, requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetAsync_CallsEnforceCallback(Guid userId)
{
@@ -86,6 +67,134 @@ public class PolicyRequirementQueryTests
Assert.Empty(requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_ReturnsRequirementPerUser(Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB };
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Any<IEnumerable<Guid>>(), PolicyType.SingleOrg)
.Returns([policyA, policyB]);
var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirements = (await sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
Assert.Equal(userIdA, requirements[0].UserId);
Assert.Equal(userIdB, requirements[1].UserId);
Assert.Contains(policyA, requirements[0].Requirement.Policies);
Assert.DoesNotContain(policyB, requirements[0].Requirement.Policies);
Assert.Contains(policyB, requirements[1].Requirement.Policies);
Assert.DoesNotContain(policyA, requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_CallsEnforceCallback(Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
var policyB = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdB };
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Any<IEnumerable<Guid>>(), PolicyType.SingleOrg)
.Returns([policyA, policyB]);
var callback = Substitute.For<Func<PolicyDetails, bool>>();
callback(Arg.Any<PolicyDetails>()).Returns(x => x.Arg<PolicyDetails>() == policyA);
var factory = new TestPolicyRequirementFactory(callback);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirements = (await sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB])).ToList();
Assert.Contains(policyA, requirements[0].Requirement.Policies);
Assert.Empty(requirements[1].Requirement.Policies);
callback.Received()(Arg.Is(policyA));
callback.Received()(Arg.Is(policyB));
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_FiltersOutPoliciesThatAreNotEnforced(Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var enforcedPolicyA = new OrganizationPolicyDetails
{ PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = false };
var notEnforcedPolicyA = new OrganizationPolicyDetails
{ PolicyType = PolicyType.SingleOrg, UserId = userIdA, IsProvider = true };
var enforcedPolicyB = new OrganizationPolicyDetails
{ PolicyType = PolicyType.SingleOrg, UserId = userIdB, IsProvider = false };
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Any<IEnumerable<Guid>>(), PolicyType.SingleOrg)
.Returns([enforcedPolicyA, notEnforcedPolicyA, enforcedPolicyB]);
// Enforce returns false for providers (filtering them out)
var factory = new TestPolicyRequirementFactory(p => !p.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirements = (await sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
Assert.Contains(enforcedPolicyA, requirements[0].Requirement.Policies);
Assert.DoesNotContain(notEnforcedPolicyA, requirements[0].Requirement.Policies);
Assert.Contains(enforcedPolicyB, requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_ThrowsIfNoFactoryRegistered(Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var sut = new PolicyRequirementQuery(policyRepository, []);
var exception = await Assert.ThrowsAsync<NotImplementedException>(()
=> sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB]));
Assert.Contains("No Requirement Factory found", exception.Message);
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_HandlesNoPolicies(Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Any<IEnumerable<Guid>>(), PolicyType.SingleOrg)
.Returns([]);
var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirements = (await sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
Assert.Equal(userIdA, requirements[0].UserId);
Assert.Equal(userIdB, requirements[1].UserId);
Assert.Empty(requirements[0].Requirement.Policies);
Assert.Empty(requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetAsync_WithMultipleUserIds_ReturnsEmptyRequirementForUserWithoutPolicies(
Guid userIdA, Guid userIdB)
{
var policyRepository = Substitute.For<IPolicyRepository>();
var policyA = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userIdA };
// Only userIdA has a policy, userIdB has none
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Any<IEnumerable<Guid>>(), PolicyType.SingleOrg)
.Returns([policyA]);
var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirements = (await sut.GetAsync<TestPolicyRequirement>([userIdA, userIdB])).ToList();
Assert.Equal(2, requirements.Count);
Assert.Equal(userIdA, requirements[0].UserId);
Assert.Equal(userIdB, requirements[1].UserId);
Assert.Contains(policyA, requirements[0].Requirement.Policies);
Assert.Empty(requirements[1].Requirement.Policies);
}
[Theory, BitAutoData]
public async Task GetManyByOrganizationIdAsync_IgnoresOtherPolicyTypes(Guid organizationId)
{