1
0
mirror of https://github.com/bitwarden/server synced 2025-12-29 22:54:00 +00:00

Merge branch 'jmccannon/ac/pm-26377-provider-auto-confirm' into jmccannon/ac/pm-27131-auto-confirm-req

This commit is contained in:
Jared McCannon
2025-12-04 13:45:10 -06:00
8 changed files with 398 additions and 322 deletions

View File

@@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
@@ -17,26 +18,13 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
/// <li>All organization users are compliant with the Single organization policy</li>
/// <li>No provider users exist</li>
/// </ul>
///
/// This class also performs side effects when the policy is being enabled or disabled. They are:
/// <ul>
/// <li>Sets the UseAutomaticUserConfirmation organization feature to match the policy update</li>
/// </ul>
/// </summary>
public class AutomaticUserConfirmationPolicyEventHandler(
IOrganizationUserRepository organizationUserRepository,
IProviderUserRepository providerUserRepository,
IPolicyRepository policyRepository,
IOrganizationRepository organizationRepository,
TimeProvider timeProvider)
: IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent, IEnforceDependentPoliciesEvent
IProviderUserRepository providerUserRepository)
: IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent
{
public PolicyType Type => PolicyType.AutomaticUserConfirmation;
public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) =>
await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy);
private const string _singleOrgPolicyNotEnabledErrorMessage =
"The Single organization policy must be enabled before enabling the Automatically confirm invited users policy.";
private const string _usersNotCompliantWithSingleOrgErrorMessage =
"All organization users must be compliant with the Single organization policy before enabling the Automatically confirm invited users policy. Please remove users who are members of multiple organizations.";
@@ -61,27 +49,20 @@ public class AutomaticUserConfirmationPolicyEventHandler(
public async Task<string> ValidateAsync(SavePolicyModel savePolicyModel, Policy? currentPolicy) =>
await ValidateAsync(savePolicyModel.PolicyUpdate, currentPolicy);
public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy)
{
var organization = await organizationRepository.GetByIdAsync(policyUpdate.OrganizationId);
if (organization is not null)
{
organization.UseAutomaticUserConfirmation = policyUpdate.Enabled;
organization.RevisionDate = timeProvider.GetUtcNow().UtcDateTime;
await organizationRepository.UpsertAsync(organization);
}
}
public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) =>
Task.CompletedTask;
private async Task<string> ValidateEnablingPolicyAsync(Guid organizationId)
{
var singleOrgValidationError = await ValidateSingleOrgPolicyComplianceAsync(organizationId);
var organizationUsers = await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var singleOrgValidationError = await ValidateUserComplianceWithSingleOrgAsync(organizationId, organizationUsers);
if (!string.IsNullOrWhiteSpace(singleOrgValidationError))
{
return singleOrgValidationError;
}
var providerValidationError = await ValidateNoProviderUsersAsync(organizationId);
var providerValidationError = await ValidateNoProviderUsersAsync(organizationUsers);
if (!string.IsNullOrWhiteSpace(providerValidationError))
{
return providerValidationError;
@@ -90,42 +71,24 @@ public class AutomaticUserConfirmationPolicyEventHandler(
return string.Empty;
}
private async Task<string> ValidateSingleOrgPolicyComplianceAsync(Guid organizationId)
private async Task<string> ValidateUserComplianceWithSingleOrgAsync(Guid organizationId,
ICollection<OrganizationUserUserDetails> organizationUsers)
{
var singleOrgPolicy = await policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.SingleOrg);
if (singleOrgPolicy is not { Enabled: true })
{
return _singleOrgPolicyNotEnabledErrorMessage;
}
return await ValidateUserComplianceWithSingleOrgAsync(organizationId);
}
private async Task<string> ValidateUserComplianceWithSingleOrgAsync(Guid organizationId)
{
var organizationUsers = (await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId))
.Where(ou => ou.Status != OrganizationUserStatusType.Invited &&
ou.Status != OrganizationUserStatusType.Revoked &&
ou.UserId.HasValue)
.ToList();
if (organizationUsers.Count == 0)
{
return string.Empty;
}
var hasNonCompliantUser = (await organizationUserRepository.GetManyByManyUsersAsync(
organizationUsers.Select(ou => ou.UserId!.Value)))
.Any(uo => uo.OrganizationId != organizationId &&
uo.Status != OrganizationUserStatusType.Invited);
.Any(uo => uo.OrganizationId != organizationId
&& uo.Status != OrganizationUserStatusType.Invited);
return hasNonCompliantUser ? _usersNotCompliantWithSingleOrgErrorMessage : string.Empty;
}
private async Task<string> ValidateNoProviderUsersAsync(Guid organizationId)
private async Task<string> ValidateNoProviderUsersAsync(ICollection<OrganizationUserUserDetails> organizationUsers)
{
var providerUsers = await providerUserRepository.GetManyByOrganizationAsync(organizationId);
var userIds = organizationUsers.Where(x => x.UserId is not null)
.Select(x => x.UserId!.Value);
return providerUsers.Count > 0 ? _providerUsersExistErrorMessage : string.Empty;
return (await providerUserRepository.GetManyByManyUsersAsync(userIds)).Count != 0
? _providerUsersExistErrorMessage
: string.Empty;
}
}

View File

@@ -12,6 +12,7 @@ public interface IProviderUserRepository : IRepository<ProviderUser, Guid>
Task<int> GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers);
Task<ICollection<ProviderUser>> GetManyAsync(IEnumerable<Guid> ids);
Task<ICollection<ProviderUser>> GetManyByUserAsync(Guid userId);
Task<ICollection<ProviderUser>> GetManyByManyUsersAsync(IEnumerable<Guid> userIds);
Task<ProviderUser?> GetByProviderUserAsync(Guid providerId, Guid userId);
Task<ICollection<ProviderUser>> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null);
Task<ICollection<ProviderUserUserDetails>> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status = null);

View File

@@ -61,6 +61,18 @@ public class ProviderUserRepository : Repository<ProviderUser, Guid>, IProviderU
}
}
public async Task<ICollection<ProviderUser>> GetManyByManyUsersAsync(IEnumerable<Guid> userIds)
{
await using var connection = new SqlConnection(ConnectionString);
var results = await connection.QueryAsync<ProviderUser>(
"[dbo].[ProviderUser_ReadManyByManyUserIds]",
new { UserIds = userIds.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
public async Task<ProviderUser?> GetByProviderUserAsync(Guid providerId, Guid userId)
{
using (var connection = new SqlConnection(ConnectionString))

View File

@@ -96,6 +96,20 @@ public class ProviderUserRepository :
return await query.ToArrayAsync();
}
}
public async Task<ICollection<ProviderUser>> GetManyByManyUsersAsync(IEnumerable<Guid> userIds)
{
await using var scope = ServiceScopeFactory.CreateAsyncScope();
var dbContext = GetDatabaseContext(scope);
var query = from pu in dbContext.ProviderUsers
where pu.UserId != null && userIds.Contains(pu.UserId.Value)
select pu;
return await query.ToArrayAsync();
}
public async Task<ProviderUser> GetByProviderUserAsync(Guid providerId, Guid userId)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[ProviderUser_ReadManyByManyUserIds]
@UserIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
[pu].*
FROM
[dbo].[ProviderUserView] AS [pu]
JOIN
@UserIds [u] ON [u].[Id] = [pu].[UserId]
END