From 07a18d31a978bdc2b3426c9e8616702e79fde257 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Thu, 30 Oct 2025 14:34:18 -0500 Subject: [PATCH 01/14] [PM-27594] - Update Org and License with Token (#6518) * Updating the license and org with claims when updating via license token. * Removing the fature flag check and adding a null check. * Added to method. --- src/Core/AdminConsole/Entities/Organization.cs | 1 + src/Core/AdminConsole/Services/OrganizationFactory.cs | 1 + .../Commands/UpdateOrganizationLicenseCommand.cs | 8 ++++++++ 3 files changed, 10 insertions(+) diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 4cbde4a61a..73aa162f22 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -333,5 +333,6 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable UseRiskInsights = license.UseRiskInsights; UseOrganizationDomains = license.UseOrganizationDomains; UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation; } } diff --git a/src/Core/AdminConsole/Services/OrganizationFactory.cs b/src/Core/AdminConsole/Services/OrganizationFactory.cs index 42d6e7c8d5..f5df3327b1 100644 --- a/src/Core/AdminConsole/Services/OrganizationFactory.cs +++ b/src/Core/AdminConsole/Services/OrganizationFactory.cs @@ -111,5 +111,6 @@ public static class OrganizationFactory UseRiskInsights = license.UseRiskInsights, UseOrganizationDomains = license.UseOrganizationDomains, UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation }; } diff --git a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs index fde95f2e70..1dfd786210 100644 --- a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs +++ b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs @@ -1,5 +1,7 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; @@ -52,6 +54,12 @@ public class UpdateOrganizationLicenseCommand : IUpdateOrganizationLicenseComman throw new BadRequestException(exception); } + var useAutomaticUserConfirmation = claimsPrincipal? + .GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation) ?? false; + + selfHostedOrganization.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + license.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + await WriteLicenseFileAsync(selfHostedOrganization, license); await UpdateOrganizationAsync(selfHostedOrganization, license); } From b8325414bf06d3395b45b8da4118f83ed27fe130 Mon Sep 17 00:00:00 2001 From: MtnBurrit0 <77340197+mimartin12@users.noreply.github.com> Date: Thu, 30 Oct 2025 13:55:28 -0600 Subject: [PATCH 02/14] Disable environment synchronization in workflow (#6525) --- .github/workflows/ephemeral-environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ephemeral-environment.yml b/.github/workflows/ephemeral-environment.yml index d85fcf2fd4..456ca573cc 100644 --- a/.github/workflows/ephemeral-environment.yml +++ b/.github/workflows/ephemeral-environment.yml @@ -16,5 +16,5 @@ jobs: with: project: server pull_request_number: ${{ github.event.number }} - sync_environment: true + sync_environment: false secrets: inherit From e102a7488e09b8f237b618707f1381961b4f00bc Mon Sep 17 00:00:00 2001 From: Vijay Oommen Date: Thu, 30 Oct 2025 16:54:05 -0500 Subject: [PATCH 03/14] [PM-26967] Added new metric properties (#6519) --- .../OrganizationReportsController.cs | 26 +++++----- .../OrganizationReportResponseModel.cs | 38 +++++++++++++++ .../Data/OrganizationReportMetricsData.cs | 48 +++++++++++++++++++ .../AddOrganizationReportCommand.cs | 18 ++++++- .../Requests/AddOrganizationReportRequest.cs | 15 +++--- .../OrganizationReportMetricsRequest.cs | 31 ++++++++++++ ...rganizationReportApplicationDataRequest.cs | 7 +-- .../UpdateOrganizationReportSummaryRequest.cs | 8 ++-- ...rganizationReportApplicationDataCommand.cs | 2 +- .../UpdateOrganizationReportSummaryCommand.cs | 4 +- .../IOrganizationReportRepository.cs | 4 ++ .../Dirt/OrganizationReportRepository.cs | 28 +++++++++++ .../OrganizationReportRepository.cs | 28 +++++++++++ .../OrganizationReport_UpdateMetrics.sql | 39 +++++++++++++++ .../OrganizationReportsControllerTests.cs | 19 +++++--- .../OrganizationReportRepositoryTests.cs | 44 +++++++++++++++++ ...30_00_OrganizationReport_UpdateMetrics.sql | 39 +++++++++++++++ 17 files changed, 359 insertions(+), 39 deletions(-) create mode 100644 src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs create mode 100644 src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs create mode 100644 src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs create mode 100644 src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql create mode 100644 util/Migrator/DbScripts/2025-10-30_00_OrganizationReport_UpdateMetrics.sql diff --git a/src/Api/Dirt/Controllers/OrganizationReportsController.cs b/src/Api/Dirt/Controllers/OrganizationReportsController.cs index bcd64b0bdf..fc9a1b2d84 100644 --- a/src/Api/Dirt/Controllers/OrganizationReportsController.cs +++ b/src/Api/Dirt/Controllers/OrganizationReportsController.cs @@ -1,4 +1,5 @@ -using Bit.Core.Context; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Exceptions; @@ -61,8 +62,9 @@ public class OrganizationReportsController : Controller } var latestReport = await _getOrganizationReportQuery.GetLatestOrganizationReportAsync(organizationId); + var response = latestReport == null ? null : new OrganizationReportResponseModel(latestReport); - return Ok(latestReport); + return Ok(response); } [HttpGet("{organizationId}/{reportId}")] @@ -102,7 +104,8 @@ public class OrganizationReportsController : Controller } var report = await _addOrganizationReportCommand.AddOrganizationReportAsync(request); - return Ok(report); + var response = report == null ? null : new OrganizationReportResponseModel(report); + return Ok(response); } [HttpPatch("{organizationId}/{reportId}")] @@ -119,7 +122,8 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportCommand.UpdateOrganizationReportAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + return Ok(response); } #endregion @@ -182,10 +186,10 @@ public class OrganizationReportsController : Controller { throw new BadRequestException("Report ID in the request body must match the route parameter"); } - var updatedReport = await _updateOrganizationReportSummaryCommand.UpdateOrganizationReportSummaryAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - return Ok(updatedReport); + return Ok(response); } #endregion @@ -228,7 +232,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportDataCommand.UpdateOrganizationReportDataAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + + return Ok(response); } #endregion @@ -265,7 +271,6 @@ public class OrganizationReportsController : Controller { try { - if (!await _currentContext.AccessReports(organizationId)) { throw new NotFoundException(); @@ -282,10 +287,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportApplicationDataCommand.UpdateOrganizationReportApplicationDataAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - - - return Ok(updatedReport); + return Ok(response); } catch (Exception ex) when (!(ex is BadRequestException || ex is NotFoundException)) { diff --git a/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs new file mode 100644 index 0000000000..e477e5b806 --- /dev/null +++ b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs @@ -0,0 +1,38 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Api.Dirt.Models.Response; + +public class OrganizationReportResponseModel +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public string? ReportData { get; set; } + public string? ContentEncryptionKey { get; set; } + public string? SummaryData { get; set; } + public string? ApplicationData { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public DateTime? CreationDate { get; set; } = null; + public DateTime? RevisionDate { get; set; } = null; + + public OrganizationReportResponseModel(OrganizationReport organizationReport) + { + if (organizationReport == null) + { + return; + } + + Id = organizationReport.Id; + OrganizationId = organizationReport.OrganizationId; + ReportData = organizationReport.ReportData; + ContentEncryptionKey = organizationReport.ContentEncryptionKey; + SummaryData = organizationReport.SummaryData; + ApplicationData = organizationReport.ApplicationData; + PasswordCount = organizationReport.PasswordCount; + PasswordAtRiskCount = organizationReport.PasswordAtRiskCount; + MemberCount = organizationReport.MemberCount; + CreationDate = organizationReport.CreationDate; + RevisionDate = organizationReport.RevisionDate; + } +} diff --git a/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs new file mode 100644 index 0000000000..ffef91275a --- /dev/null +++ b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs @@ -0,0 +1,48 @@ +using Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +namespace Bit.Core.Dirt.Reports.Models.Data; + +public class OrganizationReportMetricsData +{ + public Guid OrganizationId { get; set; } + public int? ApplicationCount { get; set; } + public int? ApplicationAtRiskCount { get; set; } + public int? CriticalApplicationCount { get; set; } + public int? CriticalApplicationAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public int? MemberAtRiskCount { get; set; } + public int? CriticalMemberCount { get; set; } + public int? CriticalMemberAtRiskCount { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? CriticalPasswordCount { get; set; } + public int? CriticalPasswordAtRiskCount { get; set; } + + public static OrganizationReportMetricsData From(Guid organizationId, OrganizationReportMetricsRequest? request) + { + if (request == null) + { + return new OrganizationReportMetricsData + { + OrganizationId = organizationId + }; + } + + return new OrganizationReportMetricsData + { + OrganizationId = organizationId, + ApplicationCount = request.ApplicationCount, + ApplicationAtRiskCount = request.ApplicationAtRiskCount, + CriticalApplicationCount = request.CriticalApplicationCount, + CriticalApplicationAtRiskCount = request.CriticalApplicationAtRiskCount, + MemberCount = request.MemberCount, + MemberAtRiskCount = request.MemberAtRiskCount, + CriticalMemberCount = request.CriticalMemberCount, + CriticalMemberAtRiskCount = request.CriticalMemberAtRiskCount, + PasswordCount = request.PasswordCount, + PasswordAtRiskCount = request.PasswordAtRiskCount, + CriticalPasswordCount = request.CriticalPasswordCount, + CriticalPasswordAtRiskCount = request.CriticalPasswordAtRiskCount + }; + } +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs index f0477806d8..236560487e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs @@ -35,14 +35,28 @@ public class AddOrganizationReportCommand : IAddOrganizationReportCommand throw new BadRequestException(errorMessage); } + var requestMetrics = request.Metrics ?? new OrganizationReportMetricsRequest(); + var organizationReport = new OrganizationReport { OrganizationId = request.OrganizationId, - ReportData = request.ReportData, + ReportData = request.ReportData ?? string.Empty, CreationDate = DateTime.UtcNow, - ContentEncryptionKey = request.ContentEncryptionKey, + ContentEncryptionKey = request.ContentEncryptionKey ?? string.Empty, SummaryData = request.SummaryData, ApplicationData = request.ApplicationData, + ApplicationCount = requestMetrics.ApplicationCount, + ApplicationAtRiskCount = requestMetrics.ApplicationAtRiskCount, + CriticalApplicationCount = requestMetrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = requestMetrics.CriticalApplicationAtRiskCount, + MemberCount = requestMetrics.MemberCount, + MemberAtRiskCount = requestMetrics.MemberAtRiskCount, + CriticalMemberCount = requestMetrics.CriticalMemberCount, + CriticalMemberAtRiskCount = requestMetrics.CriticalMemberAtRiskCount, + PasswordCount = requestMetrics.PasswordCount, + PasswordAtRiskCount = requestMetrics.PasswordAtRiskCount, + CriticalPasswordCount = requestMetrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = requestMetrics.CriticalPasswordAtRiskCount, RevisionDate = DateTime.UtcNow }; diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs index 2a8c0203f9..eecc84c522 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs @@ -1,16 +1,15 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class AddOrganizationReportRequest { public Guid OrganizationId { get; set; } - public string ReportData { get; set; } + public string? ReportData { get; set; } - public string ContentEncryptionKey { get; set; } + public string? ContentEncryptionKey { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } + + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs new file mode 100644 index 0000000000..9403a5f1c2 --- /dev/null +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs @@ -0,0 +1,31 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +public class OrganizationReportMetricsRequest +{ + [JsonPropertyName("totalApplicationCount")] + public int? ApplicationCount { get; set; } = null; + [JsonPropertyName("totalAtRiskApplicationCount")] + public int? ApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalApplicationCount")] + public int? CriticalApplicationCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskApplicationCount")] + public int? CriticalApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalMemberCount")] + public int? MemberCount { get; set; } = null; + [JsonPropertyName("totalAtRiskMemberCount")] + public int? MemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalMemberCount")] + public int? CriticalMemberCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskMemberCount")] + public int? CriticalMemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalPasswordCount")] + public int? PasswordCount { get; set; } = null; + [JsonPropertyName("totalAtRiskPasswordCount")] + public int? PasswordAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalPasswordCount")] + public int? CriticalPasswordCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskPasswordCount")] + public int? CriticalPasswordAtRiskCount { get; set; } = null; +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs index ab4fcc5921..e549a3f120 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs @@ -1,11 +1,8 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportApplicationDataRequest { public Guid Id { get; set; } public Guid OrganizationId { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs index b0e555fcef..27358537c2 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs @@ -1,11 +1,9 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportSummaryRequest { public Guid OrganizationId { get; set; } public Guid ReportId { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs index 67ec49d004..375b766a0e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs @@ -53,7 +53,7 @@ public class UpdateOrganizationReportApplicationDataCommand : IUpdateOrganizatio throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData); + var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report application data {reportId} for organization {organizationId}", request.Id, request.OrganizationId); diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs index 6859814d65..5d0f2670ca 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Dirt.Repositories; @@ -53,7 +54,8 @@ public class UpdateOrganizationReportSummaryCommand : IUpdateOrganizationReportS throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData); + await _organizationReportRepo.UpdateMetricsAsync(request.ReportId, OrganizationReportMetricsData.From(request.OrganizationId, request.Metrics)); + var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report summary {reportId} for organization {organizationId}", request.ReportId, request.OrganizationId); diff --git a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs index 9687173716..b4c2f90566 100644 --- a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs +++ b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs @@ -1,5 +1,6 @@ using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Repositories; namespace Bit.Core.Dirt.Repositories; @@ -21,5 +22,8 @@ public interface IOrganizationReportRepository : IRepository GetApplicationDataAsync(Guid reportId); Task UpdateApplicationDataAsync(Guid orgId, Guid reportId, string applicationData); + + // Metrics methods + Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics); } diff --git a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs index 3d001cce92..c704a208d1 100644 --- a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs +++ b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs @@ -4,6 +4,7 @@ using System.Data; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.Dapper.Repositories; @@ -173,4 +174,31 @@ public class OrganizationReportRepository : Repository commandType: CommandType.StoredProcedure); } } + + public async Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using var connection = new SqlConnection(ConnectionString); + var parameters = new + { + Id = reportId, + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }; + + await connection.ExecuteAsync( + $"[{Schema}].[OrganizationReport_UpdateMetrics]", + parameters, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs index 525c5a479d..d08e70c353 100644 --- a/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs @@ -4,6 +4,7 @@ using AutoMapper; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Infrastructure.EntityFramework.Repositories; using LinqToDB; @@ -184,4 +185,31 @@ public class OrganizationReportRepository : return Mapper.Map(updatedReport); } } + + public Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + return dbContext.OrganizationReports + .Where(p => p.Id == reportId) + .UpdateAsync(p => new Models.OrganizationReport + { + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }); + } + } } diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql new file mode 100644 index 0000000000..8b06c90fe1 --- /dev/null +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql @@ -0,0 +1,39 @@ +CREATE PROCEDURE [dbo].[OrganizationReport_UpdateMetrics] + @Id UNIQUEIDENTIFIER, + @ApplicationCount INT, + @ApplicationAtRiskCount INT, + @CriticalApplicationCount INT, + @CriticalApplicationAtRiskCount INT, + @MemberCount INT, + @MemberAtRiskCount INT, + @CriticalMemberCount INT, + @CriticalMemberAtRiskCount INT, + @PasswordCount INT, + @PasswordAtRiskCount INT, + @CriticalPasswordCount INT, + @CriticalPasswordAtRiskCount INT, + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + UPDATE + [dbo].[OrganizationReport] + SET + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount, + [RevisionDate] = @RevisionDate + WHERE + [Id] = @Id + +END diff --git a/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs b/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs index c786fd1c1b..880be1e4d9 100644 --- a/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs +++ b/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs @@ -1,4 +1,5 @@ using Bit.Api.Dirt.Controllers; +using Bit.Api.Dirt.Models.Response; using Bit.Core.Context; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; @@ -39,7 +40,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -262,7 +264,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -365,7 +368,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -597,7 +601,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -812,7 +817,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -1050,7 +1056,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] diff --git a/test/Infrastructure.EFIntegration.Test/Dirt/Repositories/OrganizationReportRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Dirt/Repositories/OrganizationReportRepositoryTests.cs index 7a1d6c5545..f2613fd241 100644 --- a/test/Infrastructure.EFIntegration.Test/Dirt/Repositories/OrganizationReportRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Dirt/Repositories/OrganizationReportRepositoryTests.cs @@ -1,6 +1,7 @@ using AutoFixture; using Bit.Core.AdminConsole.Entities; using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Core.Repositories; using Bit.Core.Test.AutoFixture.Attributes; @@ -489,6 +490,49 @@ public class OrganizationReportRepositoryTests Assert.Null(result); } + [CiSkippedTheory, EfOrganizationReportAutoData] + public async Task UpdateMetricsAsync_ShouldUpdateMetricsCorrectly( + OrganizationReportRepository sqlOrganizationReportRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) + { + // Arrange + var (org, report) = await CreateOrganizationAndReportAsync(sqlOrganizationRepo, sqlOrganizationReportRepo); + var metrics = new OrganizationReportMetricsData + { + ApplicationCount = 10, + ApplicationAtRiskCount = 2, + CriticalApplicationCount = 5, + CriticalApplicationAtRiskCount = 1, + MemberCount = 20, + MemberAtRiskCount = 4, + CriticalMemberCount = 10, + CriticalMemberAtRiskCount = 2, + PasswordCount = 100, + PasswordAtRiskCount = 15, + CriticalPasswordCount = 50, + CriticalPasswordAtRiskCount = 5 + }; + + // Act + await sqlOrganizationReportRepo.UpdateMetricsAsync(report.Id, metrics); + var updatedReport = await sqlOrganizationReportRepo.GetByIdAsync(report.Id); + + // Assert + Assert.Equal(metrics.ApplicationCount, updatedReport.ApplicationCount); + Assert.Equal(metrics.ApplicationAtRiskCount, updatedReport.ApplicationAtRiskCount); + Assert.Equal(metrics.CriticalApplicationCount, updatedReport.CriticalApplicationCount); + Assert.Equal(metrics.CriticalApplicationAtRiskCount, updatedReport.CriticalApplicationAtRiskCount); + Assert.Equal(metrics.MemberCount, updatedReport.MemberCount); + Assert.Equal(metrics.MemberAtRiskCount, updatedReport.MemberAtRiskCount); + Assert.Equal(metrics.CriticalMemberCount, updatedReport.CriticalMemberCount); + Assert.Equal(metrics.CriticalMemberAtRiskCount, updatedReport.CriticalMemberAtRiskCount); + Assert.Equal(metrics.PasswordCount, updatedReport.PasswordCount); + Assert.Equal(metrics.PasswordAtRiskCount, updatedReport.PasswordAtRiskCount); + Assert.Equal(metrics.CriticalPasswordCount, updatedReport.CriticalPasswordCount); + Assert.Equal(metrics.CriticalPasswordAtRiskCount, updatedReport.CriticalPasswordAtRiskCount); + } + + private async Task<(Organization, OrganizationReport)> CreateOrganizationAndReportAsync( IOrganizationRepository orgRepo, IOrganizationReportRepository orgReportRepo) diff --git a/util/Migrator/DbScripts/2025-10-30_00_OrganizationReport_UpdateMetrics.sql b/util/Migrator/DbScripts/2025-10-30_00_OrganizationReport_UpdateMetrics.sql new file mode 100644 index 0000000000..b07481f876 --- /dev/null +++ b/util/Migrator/DbScripts/2025-10-30_00_OrganizationReport_UpdateMetrics.sql @@ -0,0 +1,39 @@ +CREATE OR ALTER PROCEDURE [dbo].[OrganizationReport_UpdateMetrics] + @Id UNIQUEIDENTIFIER, + @ApplicationCount INT, + @ApplicationAtRiskCount INT, + @CriticalApplicationCount INT, + @CriticalApplicationAtRiskCount INT, + @MemberCount INT, + @MemberAtRiskCount INT, + @CriticalMemberCount INT, + @CriticalMemberAtRiskCount INT, + @PasswordCount INT, + @PasswordAtRiskCount INT, + @CriticalPasswordCount INT, + @CriticalPasswordAtRiskCount INT, + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + UPDATE + [dbo].[OrganizationReport] + SET + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount, + [RevisionDate] = @RevisionDate + WHERE + [Id] = @Id + +END \ No newline at end of file From 410e754cd9295327de386b46607b5db8a6e7a2c2 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:37:01 -0500 Subject: [PATCH 04/14] [PM-27553] Resolve premium purchase for user with account credit that used payment method (#6514) * Update payment method for customer purchasing premium who has account credit but used a payment method * Claude feedback + dotnet run format --- .../Billing/Payment/Models/PaymentMethod.cs | 2 + ...tePremiumCloudHostedSubscriptionCommand.cs | 104 +++++++++++------- ...miumCloudHostedSubscriptionCommandTests.cs | 81 +++++++++++++- 3 files changed, 145 insertions(+), 42 deletions(-) diff --git a/src/Core/Billing/Payment/Models/PaymentMethod.cs b/src/Core/Billing/Payment/Models/PaymentMethod.cs index a6835f9a32..b0733da414 100644 --- a/src/Core/Billing/Payment/Models/PaymentMethod.cs +++ b/src/Core/Billing/Payment/Models/PaymentMethod.cs @@ -11,7 +11,9 @@ public class PaymentMethod(OneOf new(tokenized); public static implicit operator PaymentMethod(NonTokenizedPaymentMethod nonTokenized) => new(nonTokenized); public bool IsTokenized => IsT0; + public TokenizedPaymentMethod AsTokenized => AsT0; public bool IsNonTokenized => IsT1; + public NonTokenizedPaymentMethod AsNonTokenized => AsT1; } internal class PaymentMethodJsonConverter : JsonConverter diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 3b2ac5343f..1f752a007b 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -2,7 +2,9 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; @@ -21,6 +23,7 @@ using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Premium.Commands; +using static StripeConstants; using static Utilities; /// @@ -32,7 +35,7 @@ public interface ICreatePremiumCloudHostedSubscriptionCommand /// /// Creates a premium cloud-hosted subscription for the specified user. /// - /// The user to create the premium subscription for. Must not already be a premium user. + /// The user to create the premium subscription for. Must not yet be a premium user. /// The tokenized payment method containing the payment type and token for billing. /// The billing address information required for tax calculation and customer creation. /// Additional storage in GB beyond the base 1GB included with premium (must be >= 0). @@ -53,7 +56,9 @@ public class CreatePremiumCloudHostedSubscriptionCommand( IUserService userService, IPushNotificationService pushNotificationService, ILogger logger, - IPricingClient pricingClient) + IPricingClient pricingClient, + IHasPaymentMethodQuery hasPaymentMethodQuery, + IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingCommand(logger), ICreatePremiumCloudHostedSubscriptionCommand { private static readonly List _expand = ["tax"]; @@ -75,10 +80,30 @@ public class CreatePremiumCloudHostedSubscriptionCommand( return new BadRequest("Additional storage must be greater than 0."); } - // Note: A customer will already exist if the customer has purchased account credits. - var customer = string.IsNullOrEmpty(user.GatewayCustomerId) - ? await CreateCustomerAsync(user, paymentMethod, billingAddress) - : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + Customer? customer; + + /* + * For a new customer purchasing a new subscription, we attach the payment method while creating the customer. + */ + if (string.IsNullOrEmpty(user.GatewayCustomerId)) + { + customer = await CreateCustomerAsync(user, paymentMethod, billingAddress); + } + /* + * An existing customer without a payment method starting a new subscription indicates a user who previously + * purchased account credit but chose to use a tokenizable payment method to pay for the subscription. In this case, + * we need to add the payment method to their customer first. If the incoming payment method is account credit, + * we can just go straight to fetching the customer since there's no payment method to apply. + */ + else if (paymentMethod.IsTokenized && !await hasPaymentMethodQuery.Run(user)) + { + await updatePaymentMethodCommand.Run(user, paymentMethod.AsTokenized, billingAddress); + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } + else + { + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } customer = await ReconcileBillingLocationAsync(customer, billingAddress); @@ -91,9 +116,9 @@ public class CreatePremiumCloudHostedSubscriptionCommand( switch (tokenized) { case { Type: TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete: + when subscription.Status == SubscriptionStatus.Incomplete: case { Type: not TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Active: + when subscription.Status == SubscriptionStatus.Active: { user.Premium = true; user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); @@ -101,13 +126,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } } }, - nonTokenized => + _ => { - if (subscription.Status == StripeConstants.SubscriptionStatus.Active) + if (subscription.Status != SubscriptionStatus.Active) { - user.Premium = true; - user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + return; } + + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); }); user.Gateway = GatewayType.Stripe; @@ -163,25 +190,25 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, - [StripeConstants.MetadataKeys.UserId] = user.Id.ToString() + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [MetadataKeys.UserId] = user.Id.ToString() }, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; var braintreeCustomerId = ""; // We have checked that the payment method is tokenized, so we can safely cast it. - // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault - switch (paymentMethod.AsT0.Type) + var tokenizedPaymentMethod = paymentMethod.AsTokenized; + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.AsT0.Token })) + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -195,19 +222,19 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } case TokenizablePaymentMethodType.Card: { - customerCreateOptions.PaymentMethod = paymentMethod.AsT0.Token; - customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.AsT0.Token; + customerCreateOptions.PaymentMethod = tokenizedPaymentMethod.Token; + customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = tokenizedPaymentMethod.Token; break; } case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.AsT0.Token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, tokenizedPaymentMethod.Token); customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } default: { - _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.AsT0.Type.ToString()); + _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, tokenizedPaymentMethod.Type.ToString()); throw new BillingException(); } } @@ -225,21 +252,18 @@ public class CreatePremiumCloudHostedSubscriptionCommand( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - if (paymentMethod.IsTokenized) + switch (tokenizedPaymentMethod.Type) { - switch (paymentMethod.AsT0.Type) - { - case TokenizablePaymentMethodType.BankAccount: - { - await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id); - break; - } - case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): - { - await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); - break; - } - } + case TokenizablePaymentMethodType.BankAccount: + { + await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id); + break; + } + case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + { + await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); + break; + } } } } @@ -271,7 +295,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( Expand = _expand, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); @@ -310,15 +334,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { Enabled = true }, - CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + CollectionMethod = CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.UserId] = userId.ToString() + [MetadataKeys.UserId] = userId.ToString() }, PaymentBehavior = usingPayPal - ? StripeConstants.PaymentBehavior.DefaultIncomplete + ? PaymentBehavior.DefaultIncomplete : null, OffSession = true }; diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index c0618f78ed..493246c578 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -2,7 +2,9 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; @@ -34,6 +36,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests private readonly IUserService _userService = Substitute.For(); private readonly IPushNotificationService _pushNotificationService = Substitute.For(); private readonly IPricingClient _pricingClient = Substitute.For(); + private readonly IHasPaymentMethodQuery _hasPaymentMethodQuery = Substitute.For(); + private readonly IUpdatePaymentMethodCommand _updatePaymentMethodCommand = Substitute.For(); private readonly CreatePremiumCloudHostedSubscriptionCommand _command; public CreatePremiumCloudHostedSubscriptionCommandTests() @@ -62,7 +66,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests _userService, _pushNotificationService, Substitute.For>(), - _pricingClient); + _pricingClient, + _hasPaymentMethodQuery, + _updatePaymentMethodCommand); } [Theory, BitAutoData] @@ -314,7 +320,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests } [Theory, BitAutoData] - public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer( + public async Task Run_UserHasExistingGatewayCustomerIdAndPaymentMethod_UsesExistingCustomer( User user, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) @@ -347,6 +353,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); + // Mock that the user has a payment method (this is the key difference from the credit purchase case) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(true); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); @@ -358,6 +366,75 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_UserPreviouslyPurchasedCreditWithoutPaymentMethod_UpdatesPaymentMethodAndCreatesSubscription( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = "existing_customer_123"; // Customer exists from previous credit purchase + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "existing_customer_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + var mockInvoice = Substitute.For(); + MaskedPaymentMethod mockMaskedPaymentMethod = new MaskedCard + { + Brand = "visa", + Last4 = "1234", + Expiration = "12/2025" + }; + + // Mock that the user does NOT have a payment method (simulating credit purchase scenario) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(false); + _updatePaymentMethodCommand.Run(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(mockMaskedPaymentMethod); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + // Verify that update payment method was called (new behavior for credit purchase case) + await _updatePaymentMethodCommand.Received(1).Run(user, paymentMethod, billingAddress); + // Verify GetCustomerOrThrow was called after updating payment method + await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); + // Verify no new customer was created + await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + // Verify subscription was created + await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + // Verify user was updated correctly + Assert.True(user.Premium); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } [Theory, BitAutoData] From d40d705aac07c2d80fb272d0376a7c509211760e Mon Sep 17 00:00:00 2001 From: Daniel James Smith <2670567+djsmith85@users.noreply.github.com> Date: Fri, 31 Oct 2025 18:40:54 +0100 Subject: [PATCH 05/14] Revert feature flag removal for Chromium importers (#6526) Co-authored-by: Daniel James Smith --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index aa1f1c934b..204a8e9d67 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -229,6 +229,7 @@ public static class FeatureFlagKeys /* Tools Team */ public const string DesktopSendUIRefresh = "desktop-send-ui-refresh"; public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators"; + public const string UseChromiumImporter = "pm-23982-chromium-importer"; public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe"; /* Vault Team */ From 21cc0b38b0944a0efa11f88135e7f9e56898676c Mon Sep 17 00:00:00 2001 From: Jimmy Vo Date: Fri, 31 Oct 2025 14:47:22 -0400 Subject: [PATCH 06/14] [PM-26401] Add logging logic (#6523) --- .../Controllers/EventsController.cs | 15 +- .../Public/Controllers/EventsController.cs | 12 +- .../DiagnosticTools/EventDiagnosticLogger.cs | 87 +++++++ src/Core/Constants.cs | 1 + .../EventDiagnosticLoggerTests.cs | 221 ++++++++++++++++++ 5 files changed, 332 insertions(+), 4 deletions(-) create mode 100644 src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs create mode 100644 test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs diff --git a/src/Api/AdminConsole/Controllers/EventsController.cs b/src/Api/AdminConsole/Controllers/EventsController.cs index f868f0b3b6..7e058a7870 100644 --- a/src/Api/AdminConsole/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Controllers/EventsController.cs @@ -3,6 +3,7 @@ using Bit.Api.Models.Response; using Bit.Api.Utilities; +using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Context; using Bit.Core.Enums; @@ -31,10 +32,11 @@ public class EventsController : Controller private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; - public EventsController( - IUserService userService, + public EventsController(IUserService userService, ICipherRepository cipherRepository, IOrganizationUserRepository organizationUserRepository, IProviderUserRepository providerUserRepository, @@ -42,7 +44,9 @@ public class EventsController : Controller ICurrentContext currentContext, ISecretRepository secretRepository, IProjectRepository projectRepository, - IServiceAccountRepository serviceAccountRepository) + IServiceAccountRepository serviceAccountRepository, + ILogger logger, + IFeatureService featureService) { _userService = userService; _cipherRepository = cipherRepository; @@ -53,6 +57,8 @@ public class EventsController : Controller _secretRepository = secretRepository; _projectRepository = projectRepository; _serviceAccountRepository = serviceAccountRepository; + _logger = logger; + _featureService = featureService; } [HttpGet("")] @@ -114,6 +120,9 @@ public class EventsController : Controller var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = continuationToken }); var responses = result.Data.Select(e => new EventResponseModel(e)); + + _logger.LogAggregateData(_featureService, orgId, responses, continuationToken, start, end); + return new ListResponseModel(responses, result.ContinuationToken); } diff --git a/src/Api/AdminConsole/Public/Controllers/EventsController.cs b/src/Api/AdminConsole/Public/Controllers/EventsController.cs index 3dd55d51e2..19edbdd5a6 100644 --- a/src/Api/AdminConsole/Public/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Public/Controllers/EventsController.cs @@ -4,9 +4,11 @@ using System.Net; using Bit.Api.Models.Public.Request; using Bit.Api.Models.Public.Response; +using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.Context; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Vault.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -20,15 +22,21 @@ public class EventsController : Controller private readonly IEventRepository _eventRepository; private readonly ICipherRepository _cipherRepository; private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; public EventsController( IEventRepository eventRepository, ICipherRepository cipherRepository, - ICurrentContext currentContext) + ICurrentContext currentContext, + ILogger logger, + IFeatureService featureService) { _eventRepository = eventRepository; _cipherRepository = cipherRepository; _currentContext = currentContext; + _logger = logger; + _featureService = featureService; } /// @@ -69,6 +77,8 @@ public class EventsController : Controller var eventResponses = result.Data.Select(e => new EventResponseModel(e)); var response = new PagedListResponseModel(eventResponses, result.ContinuationToken); + + _logger.LogAggregateData(_featureService, _currentContext.OrganizationId!.Value, response, request); return new JsonResult(response); } } diff --git a/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs new file mode 100644 index 0000000000..9f6a8d2639 --- /dev/null +++ b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs @@ -0,0 +1,87 @@ +using Bit.Api.Models.Public.Request; +using Bit.Api.Models.Public.Response; +using Bit.Core; +using Bit.Core.Services; + +namespace Bit.Api.Utilities.DiagnosticTools; + +public static class EventDiagnosticLogger +{ + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + PagedListResponseModel data, EventFilterRequestModel request) + { + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.Data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(data.ContinuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd} ActingUserId:{ActingUserId} ItemId:{ItemId},", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + request.Start?.ToString("o"), + request.End?.ToString("o"), + request.ActingUserId, + request.ItemId); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } + + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + IEnumerable data, + string? continuationToken, + DateTime? queryStart = null, + DateTime? queryEnd = null) + { + + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(continuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd}", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + queryStart?.ToString("o"), + queryEnd?.ToString("o")); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } +} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 204a8e9d67..d147f3908b 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -252,6 +252,7 @@ public static class FeatureFlagKeys /* DIRT Team */ public const string PM22887_RiskInsightsActivityTab = "pm-22887-risk-insights-activity-tab"; public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike"; + public const string EventDiagnosticLogging = "pm-27666-siem-event-log-debugging"; public static List GetAllKeys() { diff --git a/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs b/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs new file mode 100644 index 0000000000..ada75b148b --- /dev/null +++ b/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs @@ -0,0 +1,221 @@ +using Bit.Api.Models.Public.Request; +using Bit.Api.Models.Public.Response; +using Bit.Api.Utilities.DiagnosticTools; +using Bit.Core; +using Bit.Core.Models.Data; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Utilities.DiagnosticTools; + +public class EventDiagnosticLoggerTests +{ + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_FeatureFlagEnabled_LogsInformation( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var request = new EventFilterRequestModel() + { + Start = DateTime.UtcNow.AddMinutes(-3), + End = DateTime.UtcNow, + ActingUserId = Guid.NewGuid(), + ItemId = Guid.NewGuid(), + }; + + var newestEvent = Substitute.For(); + newestEvent.Date.Returns(DateTime.UtcNow); + var middleEvent = Substitute.For(); + middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1)); + var oldestEvent = Substitute.For(); + oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-3)); + + var eventResponses = new List + { + new (newestEvent), + new (middleEvent), + new (oldestEvent) + }; + var response = new PagedListResponseModel(eventResponses, "continuation-token"); + + // Act + logger.LogAggregateData(featureService, organizationId, response, request); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains($"Event count:{eventResponses.Count}") && + o.ToString().Contains($"newest record:{newestEvent.Date:O}") && + o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") && + o.ToString().Contains("HasMore:True") && + o.ToString().Contains($"Start:{request.Start:o}") && + o.ToString().Contains($"End:{request.End:o}") && + o.ToString().Contains($"ActingUserId:{request.ActingUserId}") && + o.ToString().Contains($"ItemId:{request.ItemId}")) + , + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_FeatureFlagDisabled_DoesNotLog( + Guid organizationId, + EventFilterRequestModel request) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false); + + PagedListResponseModel dummy = null; + + // Act + logger.LogAggregateData(featureService, organizationId, dummy, request); + + // Assert + logger.DidNotReceive().Log( + LogLevel.Information, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_EmptyData_LogsZeroCount( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var request = new EventFilterRequestModel() + { + Start = null, + End = null, + ActingUserId = null, + ItemId = null, + ContinuationToken = null, + }; + var response = new PagedListResponseModel(new List(), null); + + // Act + logger.LogAggregateData(featureService, organizationId, response, request); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains("Event count:0") && + o.ToString().Contains("HasMore:False")), + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_FeatureFlagDisabled_DoesNotLog(Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false); + + + // Act + logger.LogAggregateData(featureService, organizationId, null, null, null, null); + + // Assert + logger.DidNotReceive().Log( + LogLevel.Information, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_EmptyData_LogsZeroCount( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + Bit.Api.Models.Response.EventResponseModel[] emptyEvents = []; + + // Act + logger.LogAggregateData(featureService, organizationId, emptyEvents, null, null, null); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains("Event count:0") && + o.ToString().Contains("HasMore:False")), + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_FeatureFlagEnabled_LogsInformation( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var newestEvent = Substitute.For(); + newestEvent.Date.Returns(DateTime.UtcNow); + var middleEvent = Substitute.For(); + middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1)); + var oldestEvent = Substitute.For(); + oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-2)); + + var events = new List + { + new (newestEvent), + new (middleEvent), + new (oldestEvent) + }; + + var queryStart = DateTime.UtcNow.AddMinutes(-3); + var queryEnd = DateTime.UtcNow; + const string continuationToken = "continuation-token"; + + // Act + logger.LogAggregateData(featureService, organizationId, events, continuationToken, queryStart, queryEnd); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains($"Event count:{events.Count}") && + o.ToString().Contains($"newest record:{newestEvent.Date:O}") && + o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") && + o.ToString().Contains("HasMore:True") && + o.ToString().Contains($"Start:{queryStart:o}") && + o.ToString().Contains($"End:{queryEnd:o}")) + , + null, + Arg.Any>()); + } +} From 09564947e8fd759a73e1bd315eeaff24f8d58ad9 Mon Sep 17 00:00:00 2001 From: Github Actions Date: Fri, 31 Oct 2025 21:38:53 +0000 Subject: [PATCH 07/14] Bumped version to 2025.10.2 --- Directory.Build.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Directory.Build.props b/Directory.Build.props index 84b8dd22be..f14574a13c 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.10.1 + 2025.10.2 Bit.$(MSBuildProjectName) enable From e11458196c7f649091f2e6c896a9b9bca9c8e856 Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Sat, 1 Nov 2025 07:55:25 +1000 Subject: [PATCH 08/14] [PM-24192] Move account recovery logic to command (#6184) * Move account recovery logic to command (temporarily duplicated behind feature flag) * Move permission checks to authorization handler * Prevent user from recovering provider member account unless they are also provider member --- ...uthorizationHandlerCollectionExtensions.cs | 9 +- .../RecoverAccountAuthorizationHandler.cs | 110 +++++++ .../OrganizationUsersController.cs | 57 +++- .../AdminRecoverAccountCommand.cs | 79 +++++ .../IAdminRecoverAccountCommand.cs | 24 ++ src/Core/Constants.cs | 1 + src/Core/Context/ICurrentContext.cs | 24 +- ...OrganizationServiceCollectionExtensions.cs | 2 + src/Core/Services/IMailService.cs | 6 +- .../Implementations/HandlebarsMailService.cs | 2 +- .../NoopImplementations/NoopMailService.cs | 2 +- ...ionUsersControllerPutResetPasswordTests.cs | 197 ++++++++++++ ...RecoverAccountAuthorizationHandlerTests.cs | 296 ++++++++++++++++++ .../OrganizationUsersControllerTests.cs | 154 +++++++++ .../CurrentContextOrganizationFixtures.cs | 21 +- .../AdminRecoverAccountCommandTests.cs | 296 ++++++++++++++++++ 16 files changed, 1261 insertions(+), 19 deletions(-) create mode 100644 src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs create mode 100644 test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs create mode 100644 test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs create mode 100644 test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs diff --git a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs index ed628105e0..233dc138a6 100644 --- a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs +++ b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs @@ -13,9 +13,10 @@ public static class AuthorizationHandlerCollectionExtensions services.TryAddEnumerable([ ServiceDescriptor.Scoped(), - ServiceDescriptor.Scoped(), - ServiceDescriptor.Scoped(), - ServiceDescriptor.Scoped(), - ]); + ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), + ]); } } diff --git a/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs new file mode 100644 index 0000000000..239148ab25 --- /dev/null +++ b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs @@ -0,0 +1,110 @@ +using System.Security.Claims; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Microsoft.AspNetCore.Authorization; + +namespace Bit.Api.AdminConsole.Authorization; + +/// +/// An authorization requirement for recovering an organization member's account. +/// +/// +/// Note: this is different to simply being able to manage account recovery. The user must be recovering +/// a member who has equal or lesser permissions than them. +/// +public class RecoverAccountAuthorizationRequirement : IAuthorizationRequirement; + +/// +/// Authorizes members and providers to recover a target OrganizationUser's account. +/// +/// +/// This prevents privilege escalation by ensuring that a user cannot recover the account of +/// another user with a higher role or with provider membership. +/// +public class RecoverAccountAuthorizationHandler( + IOrganizationContext organizationContext, + ICurrentContext currentContext, + IProviderUserRepository providerUserRepository) + : AuthorizationHandler +{ + public const string FailureReason = "You are not permitted to recover this user's account."; + public const string ProviderFailureReason = "You are not permitted to recover a Provider member's account."; + + protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, + RecoverAccountAuthorizationRequirement requirement, + OrganizationUser targetOrganizationUser) + { + // Step 1: check that the User has permissions with respect to the organization. + // This may come from their role in the organization or their provider relationship. + var canRecoverOrganizationMember = + AuthorizeMember(context.User, targetOrganizationUser) || + await AuthorizeProviderAsync(context.User, targetOrganizationUser); + + if (!canRecoverOrganizationMember) + { + context.Fail(new AuthorizationFailureReason(this, FailureReason)); + return; + } + + // Step 2: check that the User has permissions with respect to any provider the target user is a member of. + // This prevents an organization admin performing privilege escalation into an unrelated provider. + var canRecoverProviderMember = await CanRecoverProviderAsync(targetOrganizationUser); + if (!canRecoverProviderMember) + { + context.Fail(new AuthorizationFailureReason(this, ProviderFailureReason)); + return; + } + + context.Succeed(requirement); + } + + private async Task AuthorizeProviderAsync(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + return await organizationContext.IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId); + } + + private bool AuthorizeMember(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = organizationContext.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId); + if (currentContextOrganization == null) + { + return false; + } + + // Current user must have equal or greater permissions than the user account being recovered + var authorized = targetOrganizationUser.Type switch + { + OrganizationUserType.Owner => currentContextOrganization.Type is OrganizationUserType.Owner, + OrganizationUserType.Admin => currentContextOrganization.Type is OrganizationUserType.Owner or OrganizationUserType.Admin, + _ => currentContextOrganization is + { Type: OrganizationUserType.Owner or OrganizationUserType.Admin } + or { Type: OrganizationUserType.Custom, Permissions.ManageResetPassword: true } + }; + + return authorized; + } + + private async Task CanRecoverProviderAsync(OrganizationUser targetOrganizationUser) + { + if (!targetOrganizationUser.UserId.HasValue) + { + // If an OrganizationUser is not linked to a User then it can't be linked to a Provider either. + // This is invalid but does not pose a privilege escalation risk. Return early and let the command + // handle the invalid input. + return true; + } + + var targetUserProviderUsers = + await providerUserRepository.GetManyByUserAsync(targetOrganizationUser.UserId.Value); + + // If the target user belongs to any provider that the current user is not a member of, + // deny the action to prevent privilege escalation from organization to provider. + // Note: we do not expect that a user is a member of more than 1 provider, but there is also no guarantee + // against it; this returns a sequence, so we handle the possibility. + var authorized = targetUserProviderUsers.All(providerUser => currentContext.ProviderUser(providerUser.ProviderId)); + return authorized; + } +} + diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 74ac9b1255..4b9f7e5d71 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below +// NOTE: This file is partially migrated to nullable reference types. Remove inline #nullable directives when addressing the FIXME. #nullable disable using Bit.Api.AdminConsole.Authorization; @@ -11,6 +12,7 @@ using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -70,6 +72,7 @@ public class OrganizationUsersController : Controller private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; private readonly IInitPendingOrganizationCommand _initPendingOrganizationCommand; private readonly IRevokeOrganizationUserCommand _revokeOrganizationUserCommand; + private readonly IAdminRecoverAccountCommand _adminRecoverAccountCommand; public OrganizationUsersController(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, @@ -97,7 +100,8 @@ public class OrganizationUsersController : Controller IRestoreOrganizationUserCommand restoreOrganizationUserCommand, IInitPendingOrganizationCommand initPendingOrganizationCommand, IRevokeOrganizationUserCommand revokeOrganizationUserCommand, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IAdminRecoverAccountCommand adminRecoverAccountCommand) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -126,6 +130,7 @@ public class OrganizationUsersController : Controller _restoreOrganizationUserCommand = restoreOrganizationUserCommand; _initPendingOrganizationCommand = initPendingOrganizationCommand; _revokeOrganizationUserCommand = revokeOrganizationUserCommand; + _adminRecoverAccountCommand = adminRecoverAccountCommand; } [HttpGet("{id}")] @@ -474,21 +479,27 @@ public class OrganizationUsersController : Controller [HttpPut("{id}/reset-password")] [Authorize] - public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) + public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) { + if (_featureService.IsEnabled(FeatureFlagKeys.AccountRecoveryCommand)) + { + // TODO: remove legacy implementation after feature flag is enabled. + return await PutResetPasswordNew(orgId, id, model); + } + // Get the users role, since provider users aren't a member of the organization we use the owner check var orgUserType = await _currentContext.OrganizationOwner(orgId) ? OrganizationUserType.Owner : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgId)?.Type; if (orgUserType == null) { - throw new NotFoundException(); + return TypedResults.NotFound(); } var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgId, id, model.NewMasterPasswordHash, model.Key); if (result.Succeeded) { - return; + return TypedResults.Ok(); } foreach (var error in result.Errors) @@ -497,9 +508,45 @@ public class OrganizationUsersController : Controller } await Task.Delay(2000); - throw new BadRequestException(ModelState); + return TypedResults.BadRequest(ModelState); } +#nullable enable + // TODO: make sure the route and authorize attributes are maintained when the legacy implementation is removed. + private async Task PutResetPasswordNew(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) + { + var targetOrganizationUser = await _organizationUserRepository.GetByIdAsync(id); + if (targetOrganizationUser == null || targetOrganizationUser.OrganizationId != orgId) + { + return TypedResults.NotFound(); + } + + var authorizationResult = await _authorizationService.AuthorizeAsync(User, targetOrganizationUser, new RecoverAccountAuthorizationRequirement()); + if (!authorizationResult.Succeeded) + { + // Return an informative error to show in the UI. + // The Authorize attribute already prevents enumeration by users outside the organization, so this can be specific. + var failureReason = authorizationResult.Failure?.FailureReasons.FirstOrDefault()?.Message ?? RecoverAccountAuthorizationHandler.FailureReason; + // This should be a 403 Forbidden, but that causes a logout on our client apps so we're using 400 Bad Request instead + return TypedResults.BadRequest(new ErrorResponseModel(failureReason)); + } + + var result = await _adminRecoverAccountCommand.RecoverAccountAsync(orgId, targetOrganizationUser, model.NewMasterPasswordHash, model.Key); + if (result.Succeeded) + { + return TypedResults.Ok(); + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + return TypedResults.BadRequest(ModelState); + } +#nullable disable + [HttpDelete("{id}")] [Authorize] public async Task Remove(Guid orgId, Guid id) diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..5783301a0b --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs @@ -0,0 +1,79 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +public class AdminRecoverAccountCommand(IOrganizationRepository organizationRepository, + IPolicyRepository policyRepository, + IUserRepository userRepository, + IMailService mailService, + IEventService eventService, + IPushNotificationService pushNotificationService, + IUserService userService, + TimeProvider timeProvider) : IAdminRecoverAccountCommand +{ + public async Task RecoverAccountAsync(Guid orgId, + OrganizationUser organizationUser, string newMasterPassword, string key) + { + // Org must be able to use reset password + var org = await organizationRepository.GetByIdAsync(orgId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset."); + } + + // Enterprise policy must be enabled + var resetPasswordPolicy = + await policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Org User must be confirmed and have a ResetPasswordKey + if (organizationUser == null || + organizationUser.Status != OrganizationUserStatusType.Confirmed || + organizationUser.OrganizationId != orgId || + string.IsNullOrEmpty(organizationUser.ResetPasswordKey) || + !organizationUser.UserId.HasValue) + { + throw new BadRequestException("Organization User not valid"); + } + + var user = await userService.GetUserByIdAsync(organizationUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("Cannot reset password of a user with Key Connector."); + } + + var result = await userService.UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = timeProvider.GetUtcNow().UtcDateTime; + user.LastPasswordChangeDate = user.RevisionDate; + user.ForcePasswordReset = true; + user.Key = key; + + await userRepository.ReplaceAsync(user); + await mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.DisplayName()); + await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_AdminResetPassword); + await pushNotificationService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..75babc643e --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs @@ -0,0 +1,24 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +/// +/// A command used to recover an organization user's account by an organization admin. +/// +public interface IAdminRecoverAccountCommand +{ + /// + /// Recovers an organization user's account by resetting their master password. + /// + /// The organization the user belongs to. + /// The organization user being recovered. + /// The user's new master password hash. + /// The user's new master-password-sealed user key. + /// An IdentityResult indicating success or failure. + /// When organization settings, policy, or user state is invalid. + /// When the user does not exist. + Task RecoverAccountAsync(Guid orgId, OrganizationUser organizationUser, + string newMasterPassword, string key); +} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index d147f3908b..fead9947a0 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -142,6 +142,7 @@ public static class FeatureFlagKeys public const string CreateDefaultLocation = "pm-19467-create-default-location"; public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; + public const string AccountRecoveryCommand = "pm-24192-account-recovery-command"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index 417e220ba2..f62a048070 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Security.Claims; +using System.Security.Claims; using Bit.Core.AdminConsole.Context; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Identity; @@ -12,6 +10,14 @@ using Microsoft.AspNetCore.Http; namespace Bit.Core.Context; +/// +/// Provides information about the current HTTP request and the currently authenticated user (if any). +/// This is often (but not exclusively) parsed from the JWT in the current request. +/// +/// +/// This interface suffers from having too much responsibility; consider whether any new code can go in a more +/// specific class rather than adding it here. +/// public interface ICurrentContext { HttpContext HttpContext { get; set; } @@ -59,8 +65,20 @@ public interface ICurrentContext Task EditSubscription(Guid orgId); Task EditPaymentMethods(Guid orgId); Task ViewBillingHistory(Guid orgId); + /// + /// Returns true if the current user is a member of a provider that manages the specified organization. + /// This generally gives the user administrative privileges for the organization. + /// + /// + /// Task ProviderUserForOrgAsync(Guid orgId); + /// + /// Returns true if the current user is a Provider Admin of the specified provider. + /// bool ProviderProviderAdmin(Guid providerId); + /// + /// Returns true if the current user is a member of the specified provider (with any role). + /// bool ProviderUser(Guid providerId); bool ProviderManageUsers(Guid providerId); bool ProviderAccessEventLogs(Guid providerId); diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index da05bc929c..8cfd0a8df1 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.OrganizationAuth; using Bit.Core.AdminConsole.OrganizationAuth.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.Groups; using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Import; @@ -133,6 +134,7 @@ public static class OrganizationServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 5a3428c25a..91bbde949b 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; @@ -92,7 +90,7 @@ public interface IMailService Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); - Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); + Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName); Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email); Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 19705766ed..e8707d13e8 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -674,7 +674,7 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public async Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { var message = CreateDefaultMessage("Your admin has initiated account recovery", email); var model = new AdminResetPasswordViewModel() diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index 1459fab966..5e7c67bd61 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -221,7 +221,7 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { return Task.FromResult(0); } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs new file mode 100644 index 0000000000..cf842d1568 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs @@ -0,0 +1,197 @@ +using System.Net; +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Request.Organizations; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + + public OrganizationUsersControllerPutResetPasswordTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.AccountRecoveryCommand) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"reset-password-test-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + // Enable reset password and policies for the organization + var organizationRepository = _factory.GetService(); + _organization.UseResetPassword = true; + _organization.UsePolicies = true; + await organizationRepository.ReplaceAsync(_organization); + + // Enable the ResetPassword policy + var policyRepository = _factory.GetService(); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = _organization.Id, + Type = PolicyType.ResetPassword, + Enabled = true, + Data = "{}" + }); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + /// + /// Helper method to set the ResetPasswordKey on an organization user, which is required for account recovery + /// + private async Task SetResetPasswordKeyAsync(OrganizationUser orgUser) + { + var organizationUserRepository = _factory.GetService(); + orgUser.ResetPasswordKey = "encrypted-reset-password-key"; + await organizationUserRepository.ReplaceAsync(orgUser); + } + + [Fact] + public async Task PutResetPassword_AsHigherRole_CanRecoverLowerRole() + { + // Arrange + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + var (_, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutResetPassword_AsLowerRole_CannotRecoverHigherRole() + { + // Arrange + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + await _loginHelper.LoginAsync(adminEmail); + + var (_, targetOwnerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.Owner); + await SetResetPasswordKeyAsync(targetOwnerOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOwnerOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Contains(RecoverAccountAuthorizationHandler.FailureReason, model.Message); + } + + [Fact] + public async Task PutResetPassword_CannotRecoverProviderAccount() + { + // Arrange - Create owner who will try to recover the provider account + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + // Create a user who is also a provider user + var (targetUserEmail, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + // Add the target user as a provider user to a different provider + var providerRepository = _factory.GetService(); + var providerUserRepository = _factory.GetService(); + var userRepository = _factory.GetService(); + + var provider = await providerRepository.CreateAsync(new Provider + { + Name = "Test Provider", + BusinessName = "Test Provider Business", + BillingEmail = "provider@example.com", + Type = ProviderType.Msp, + Status = ProviderStatusType.Created, + Enabled = true + }); + + var targetUser = await userRepository.GetByEmailAsync(targetUserEmail); + Assert.NotNull(targetUser); + + await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = provider.Id, + UserId = targetUser.Id, + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin + }); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Equal(RecoverAccountAuthorizationHandler.ProviderFailureReason, model.Message); + } +} diff --git a/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs b/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs new file mode 100644 index 0000000000..92efb641f1 --- /dev/null +++ b/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs @@ -0,0 +1,296 @@ +using System.Security.Claims; +using Bit.Api.AdminConsole.Authorization; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Authorization; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Authorization; + +[SutProviderCustomize] +public class RecoverAccountAuthorizationHandlerTests +{ + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsProvider_TargetUserNotProvider_Authorized( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsNotMemberOrProvider_NotAuthorized( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ can recover the account for a ___ + public static IEnumerable AuthorizedRoleCombinations => new object[][] + { + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(AuthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_RecoverEqualOrLesserRoles_TargetUserNotProvider_Authorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ cannot recover the account for a ___ + public static IEnumerable UnauthorizedRoleCombinations => new object[][] + { + // These roles should fail because you cannot recover a greater role + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true} }, OrganizationUserType.Admin], + + // These roles are never authorized to recover any account + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(UnauthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_InvalidRoles_TargetUserNotProvider_Unauthorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_TargetUserIdIsNull_DoesNotBlock( + SutProvider sutProvider, + OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.UserId = null; + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + // This should shortcut the provider escalation check + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetManyByUserAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsMemberOfAllTargetUserProviders_DoesNotBlock( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(true); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserMissingProviderMembership_Blocks( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + // Not a member of this provider + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(false); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.ProviderFailureReason); + } + + private static void MockOrganizationClaims(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser, + CurrentContextOrganization? currentContextOrganization) + { + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void MockCurrentUserIsProvider(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + sutProvider.GetDependency() + .IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId) + .Returns(true); + } + + private static void MockCurrentUserIsOwner(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = new CurrentContextOrganization + { + Id = targetOrganizationUser.OrganizationId, + Type = OrganizationUserType.Owner + }; + + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void AssertFailed(AuthorizationHandlerContext context, string expectedMessage) + { + Assert.True(context.HasFailed); + var failureReason = Assert.Single(context.FailureReasons); + Assert.Equal(expectedMessage, failureReason.Message); + } +} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs index e5aa03f067..5875cda05a 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs @@ -1,11 +1,14 @@ using System.Security.Claims; +using Bit.Api.AdminConsole.Authorization; using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.Models.Request.Organizations; using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; @@ -16,6 +19,7 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; @@ -30,6 +34,7 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; +using Microsoft.AspNetCore.Mvc.ModelBinding; using NSubstitute; using Xunit; @@ -440,4 +445,153 @@ public class OrganizationUsersControllerTests Assert.Equal("Master Password reset is required, but not provided.", exception.Message); } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_CallsLegacyPath( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); + sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1) + .AdminResetPasswordAsync(OrganizationUserType.Owner, orgId, orgUserId, model.NewMasterPasswordHash, model.Key); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_WhenOrgUserTypeIsNull_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(false); + sutProvider.GetDependency().Organizations.Returns(new List()); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_WhenAdminResetPasswordFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); + sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error 1" })); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationUserNotFound_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns((OrganizationUser)null); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationIdMismatch_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = Guid.NewGuid(); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenAuthorizationFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Failed()); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountSucceeds_ReturnsOk( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1) + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error message" })); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } } diff --git a/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs index 080b8ec62e..1c809f604d 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs @@ -1,4 +1,6 @@ -using AutoFixture; +using System.Reflection; +using AutoFixture; +using AutoFixture.Xunit2; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models.Data; @@ -23,6 +25,7 @@ public class CurrentContextOrganizationCustomization : ICustomization } } +[AttributeUsage(AttributeTargets.Method)] public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribute { public Guid Id { get; set; } @@ -38,3 +41,19 @@ public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribut AccessSecretsManager = AccessSecretsManager }; } + +public class CurrentContextOrganizationAttribute : CustomizeAttribute +{ + public Guid Id { get; set; } + public OrganizationUserType Type { get; set; } = OrganizationUserType.User; + public Permissions Permissions { get; set; } = new(); + public bool AccessSecretsManager { get; set; } = false; + + public override ICustomization GetCustomization(ParameterInfo _) => new CurrentContextOrganizationCustomization + { + Id = Id, + Type = Type, + Permissions = Permissions, + AccessSecretsManager = AccessSecretsManager + }; +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs new file mode 100644 index 0000000000..88025301b6 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs @@ -0,0 +1,296 @@ +using AutoFixture; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Identity; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.AccountRecovery; + +[SutProviderCustomize] +public class AdminRecoverAccountCommandTests +{ + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_Success( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + User user, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + SetupValidUser(sutProvider, user, organizationUser); + SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword); + + // Act + var result = await sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key); + + // Assert + Assert.True(result.Succeeded); + await AssertSuccessAsync(sutProvider, user, key, organization, organizationUser); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_OrganizationDoesNotExist_ThrowsBadRequest( + [OrganizationUser] OrganizationUser organizationUser, + string newMasterPassword, + string key, + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns((Organization)null); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(orgId, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization does not allow password reset.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_OrganizationDoesNotAllowResetPassword_ThrowsBadRequest( + string newMasterPassword, + string key, + Organization organization, + [OrganizationUser] OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + organization.UseResetPassword = false; + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization does not allow password reset.", exception.Message); + } + + public static IEnumerable InvalidPolicies => new object[][] + { + [new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null] + }; + + [Theory] + [BitMemberAutoData(nameof(InvalidPolicies))] + public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest( + Policy resetPasswordPolicy, + string newMasterPassword, + string key, + Organization organization, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) + .Returns(resetPasswordPolicy); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, new OrganizationUser { Id = Guid.NewGuid() }, + newMasterPassword, key)); + Assert.Equal("Organization does not have the password reset policy enabled.", exception.Message); + } + + public static IEnumerable InvalidOrganizationUsers() + { + // Make an organization so we can use its Id + var organization = new Fixture().Create(); + + var nonConfirmed = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = organization.Id, + Status = OrganizationUserStatusType.Invited + }; + yield return [nonConfirmed, organization]; + + var wrongOrganization = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = Guid.NewGuid(), // Different org + ResetPasswordKey = "test-key", + UserId = Guid.NewGuid(), + }; + yield return [wrongOrganization, organization]; + + var nullResetPasswordKey = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = null, + UserId = Guid.NewGuid(), + }; + yield return [nullResetPasswordKey, organization]; + + var emptyResetPasswordKey = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = "", + UserId = Guid.NewGuid(), + }; + yield return [emptyResetPasswordKey, organization]; + + var nullUserId = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = "test-key", + UserId = null, + }; + yield return [nullUserId, organization]; + } + + [Theory] + [BitMemberAutoData(nameof(InvalidOrganizationUsers))] + public async Task RecoverAccountAsync_OrganizationUserIsInvalid_ThrowsBadRequest( + OrganizationUser organizationUser, + Organization organization, + string newMasterPassword, + string key, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization User not valid", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_UserDoesNotExist_ThrowsNotFoundException( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + sutProvider.GetDependency() + .GetUserByIdAsync(organizationUser.UserId!.Value) + .Returns((User)null); + + // Act & Assert + await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_UserUsesKeyConnector_ThrowsBadRequest( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + User user, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + user.UsesKeyConnector = true; + sutProvider.GetDependency() + .GetUserByIdAsync(organizationUser.UserId!.Value) + .Returns(user); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Cannot reset password of a user with Key Connector.", exception.Message); + } + + private static void SetupValidOrganization(SutProvider sutProvider, Organization organization) + { + organization.UseResetPassword = true; + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + } + + private static void SetupValidPolicy(SutProvider sutProvider, Organization organization) + { + var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true }; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) + .Returns(policy); + } + + private static void SetupValidOrganizationUser(OrganizationUser organizationUser, Guid orgId) + { + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organizationUser.OrganizationId = orgId; + organizationUser.ResetPasswordKey = "test-key"; + organizationUser.Type = OrganizationUserType.User; + } + + private static void SetupValidUser(SutProvider sutProvider, User user, OrganizationUser organizationUser) + { + user.Id = organizationUser.UserId!.Value; + user.UsesKeyConnector = false; + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + } + + private static void SetupSuccessfulPasswordUpdate(SutProvider sutProvider, User user, string newMasterPassword) + { + sutProvider.GetDependency() + .UpdatePasswordHash(user, newMasterPassword) + .Returns(IdentityResult.Success); + } + + private static async Task AssertSuccessAsync(SutProvider sutProvider, User user, string key, + Organization organization, OrganizationUser organizationUser) + { + await sutProvider.GetDependency().Received(1).ReplaceAsync( + Arg.Is(u => + u.Id == user.Id && + u.Key == key && + u.ForcePasswordReset == true && + u.RevisionDate == u.AccountRevisionDate && + u.LastPasswordChangeDate == u.RevisionDate)); + + await sutProvider.GetDependency().Received(1).SendAdminResetPasswordEmailAsync( + Arg.Is(user.Email), + Arg.Is(user.Name), + Arg.Is(organization.DisplayName())); + + await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync( + Arg.Is(organizationUser), + Arg.Is(EventType.OrganizationUser_AdminResetPassword)); + + await sutProvider.GetDependency().Received(1).PushLogOutAsync( + Arg.Is(user.Id)); + } +} From 0ea9e2e48ab09d8e111598c95d1d8aaed7560bb7 Mon Sep 17 00:00:00 2001 From: Github Actions Date: Mon, 3 Nov 2025 14:29:04 +0000 Subject: [PATCH 09/14] Bumped version to 2025.11.0 --- Directory.Build.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Directory.Build.props b/Directory.Build.props index f14574a13c..4511202024 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.10.2 + 2025.11.0 Bit.$(MSBuildProjectName) enable From de56b7f3278be3e74cc65c03570ece1481f7bf24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:24:40 +0000 Subject: [PATCH 10/14] [PM-26099] Update public list members endpoint to include collections (#6503) * Add CreateCollectionAsync method to OrganizationTestHelpers for collection creation with user and group associations * Update public MembersController List endpoint to include associated collections in member response model * Update MembersControllerTests to validate collection associations in List endpoint. Add JsonConstructor to AssociationWithPermissionsResponseModel * Refactor MembersController by removing unused IUserService and IApplicationCacheService dependencies. * Remove nullable disable directive from Public MembersController --- .../Public/Controllers/MembersController.cs | 30 ++++------ ...AssociationWithPermissionsResponseModel.cs | 8 ++- .../Controllers/MembersControllerTests.cs | 55 +++++++++++++++---- .../Helpers/OrganizationTestHelpers.cs | 22 ++++++++ 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 7bfe5648b6..3b2e82121d 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Net; +using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; @@ -24,11 +21,9 @@ public class MembersController : Controller private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IGroupRepository _groupRepository; private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; private readonly ICurrentContext _currentContext; private readonly IUpdateOrganizationUserCommand _updateOrganizationUserCommand; private readonly IUpdateOrganizationUserGroupsCommand _updateOrganizationUserGroupsCommand; - private readonly IApplicationCacheService _applicationCacheService; private readonly IPaymentService _paymentService; private readonly IOrganizationRepository _organizationRepository; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; @@ -39,11 +34,9 @@ public class MembersController : Controller IOrganizationUserRepository organizationUserRepository, IGroupRepository groupRepository, IOrganizationService organizationService, - IUserService userService, ICurrentContext currentContext, IUpdateOrganizationUserCommand updateOrganizationUserCommand, IUpdateOrganizationUserGroupsCommand updateOrganizationUserGroupsCommand, - IApplicationCacheService applicationCacheService, IPaymentService paymentService, IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, @@ -53,11 +46,9 @@ public class MembersController : Controller _organizationUserRepository = organizationUserRepository; _groupRepository = groupRepository; _organizationService = organizationService; - _userService = userService; _currentContext = currentContext; _updateOrganizationUserCommand = updateOrganizationUserCommand; _updateOrganizationUserGroupsCommand = updateOrganizationUserGroupsCommand; - _applicationCacheService = applicationCacheService; _paymentService = paymentService; _organizationRepository = organizationRepository; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; @@ -115,19 +106,18 @@ public class MembersController : Controller /// /// /// Returns a list of your organization's members. - /// Member objects listed in this call do not include information about their associated collections. + /// Member objects listed in this call include information about their associated collections. /// [HttpGet] [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] public async Task List() { - var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId.Value); - // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. + var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId!.Value, includeCollections: true); var orgUsersTwoFactorIsEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(organizationUserUserDetails); var memberResponses = organizationUserUserDetails.Select(u => { - return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, null); + return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, u.Collections); }); var response = new ListResponseModel(memberResponses); return new JsonResult(response); @@ -158,7 +148,7 @@ public class MembersController : Controller invite.AccessSecretsManager = hasStandaloneSecretsManager; - var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, + var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId!.Value, null, systemUser: null, invite, model.ExternalId); var response = new MemberResponseModel(user, invite.Collections); return new JsonResult(response); @@ -188,12 +178,12 @@ public class MembersController : Controller var updatedUser = model.ToOrganizationUser(existingUser); var associations = model.Collections?.Select(c => c.ToCollectionAccessSelection()).ToList(); await _updateOrganizationUserCommand.UpdateUserAsync(updatedUser, existingUserType, null, associations, model.Groups); - MemberResponseModel response = null; + MemberResponseModel response; if (existingUser.UserId.HasValue) { var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - response = new MemberResponseModel(existingUserDetails, - await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails), associations); + response = new MemberResponseModel(existingUserDetails!, + await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails!), associations); } else { @@ -242,7 +232,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId.Value, id, null); + await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId!.Value, id, null); return new OkResult(); } @@ -264,7 +254,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); + await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId!.Value, null, id); return new OkResult(); } } diff --git a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs index e319ead8a4..5ff12a2201 100644 --- a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs @@ -1,9 +1,15 @@ -using Bit.Core.Models.Data; +using System.Text.Json.Serialization; +using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Public.Models.Response; public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel { + [JsonConstructor] + public AssociationWithPermissionsResponseModel() : base() + { + } + public AssociationWithPermissionsResponseModel(CollectionAccessSelection selection) { if (selection == null) diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs index 11c60ad57c..2eeba5d47e 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs @@ -64,6 +64,17 @@ public class MembersControllerTests : IClassFixture, IAsy var (userEmail4, orgUser4) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.Admin); + var collection1 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 1", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = false, Manage = true }, + new CollectionAccessSelection { Id = orgUser3.Id, ReadOnly = true, HidePasswords = false, Manage = false } + ]); + + var collection2 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 2", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = true, Manage = false } + ]); + var response = await _client.GetAsync($"/public/members"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); var result = await response.Content.ReadFromJsonAsync>(); @@ -71,23 +82,47 @@ public class MembersControllerTests : IClassFixture, IAsy Assert.Equal(5, result.Data.Count()); // The owner - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner)); + var ownerResult = result.Data.SingleOrDefault(m => m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner); + Assert.NotNull(ownerResult); + Assert.Empty(ownerResult.Collections); - // The custom user + // The custom user with collections var user1Result = result.Data.Single(m => m.Email == userEmail1); Assert.Equal(OrganizationUserType.Custom, user1Result.Type); AssertHelper.AssertPropertyEqual( new PermissionsModel { AccessImportExport = true, ManagePolicies = true, AccessReports = true }, user1Result.Permissions); + // Verify collections + Assert.NotNull(user1Result.Collections); + Assert.Equal(2, user1Result.Collections.Count()); + var user1Collection1 = user1Result.Collections.Single(c => c.Id == collection1.Id); + Assert.False(user1Collection1.ReadOnly); + Assert.False(user1Collection1.HidePasswords); + Assert.True(user1Collection1.Manage); + var user1Collection2 = user1Result.Collections.Single(c => c.Id == collection2.Id); + Assert.False(user1Collection2.ReadOnly); + Assert.True(user1Collection2.HidePasswords); + Assert.False(user1Collection2.Manage); - // Everyone else - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail2 && m.Type == OrganizationUserType.Owner)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail3 && m.Type == OrganizationUserType.User)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail4 && m.Type == OrganizationUserType.Admin)); + // The other owner + var user2Result = result.Data.SingleOrDefault(m => m.Email == userEmail2 && m.Type == OrganizationUserType.Owner); + Assert.NotNull(user2Result); + Assert.Empty(user2Result.Collections); + + // The user with one collection + var user3Result = result.Data.SingleOrDefault(m => m.Email == userEmail3 && m.Type == OrganizationUserType.User); + Assert.NotNull(user3Result); + Assert.NotNull(user3Result.Collections); + Assert.Single(user3Result.Collections); + var user3Collection1 = user3Result.Collections.Single(c => c.Id == collection1.Id); + Assert.True(user3Collection1.ReadOnly); + Assert.False(user3Collection1.HidePasswords); + Assert.False(user3Collection1.Manage); + + // The admin with no collections + var user4Result = result.Data.SingleOrDefault(m => m.Email == userEmail4 && m.Type == OrganizationUserType.Admin); + Assert.NotNull(user4Result); + Assert.Empty(user4Result.Collections); } [Fact] diff --git a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs index 3cd73c4b1c..c23ebff736 100644 --- a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs +++ b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs @@ -151,6 +151,28 @@ public static class OrganizationTestHelpers return group; } + /// + /// Creates a collection with optional user and group associations. + /// + public static async Task CreateCollectionAsync( + ApiApplicationFactory factory, + Guid organizationId, + string name, + IEnumerable? users = null, + IEnumerable? groups = null) + { + var collectionRepository = factory.GetService(); + var collection = new Collection + { + OrganizationId = organizationId, + Name = name, + Type = CollectionType.SharedCollection + }; + + await collectionRepository.CreateAsync(collection, groups, users); + return collection; + } + /// /// Enables the Organization Data Ownership policy for the specified organization. /// From 1e2e4b9d4d369ac4d5b5d3a51b19452c62bf3222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:44:44 +0000 Subject: [PATCH 11/14] [PM-26429] Add validation to policy data and metadata (#6460) * Enhance PolicyRequestModel and SavePolicyRequest with validation for policy data and metadata. * Add integration tests for policy updates to validate handling of invalid data types in PolicyRequestModel and SavePolicyRequest. * Add missing using * Update PolicyRequestModel for null safety by making Data and ValidateAndSerializePolicyData nullable * Add integration tests for public PoliciesController to validate handling of invalid data types in policy updates. * Add PolicyDataValidator class for validating and serializing policy data and metadata based on policy type. * Refactor PolicyRequestModel, SavePolicyRequest, and PolicyUpdateRequestModel to utilize PolicyDataValidator for data validation and serialization, removing redundant methods and improving code clarity. * Update PolicyRequestModel and SavePolicyRequest to initialize Data and Metadata properties with empty dictionaries. * Refactor PolicyDataValidator to remove null checks for input data in validation methods * Rename test methods in SavePolicyRequestTests to reflect handling of empty data and metadata, and remove null assignments in test cases for improved clarity. * Enhance error handling in PolicyDataValidator to include field-specific details in BadRequestException messages. * Enhance PoliciesControllerTests to verify error messages for BadRequest responses by checking for specific field names in the response content. * refactor: Update PolicyRequestModel and SavePolicyRequest to use nullable dictionaries for Data and Metadata properties; enhance validation methods in PolicyDataValidator to handle null cases. * test: Add integration tests for handling policies with null data in PoliciesController * fix: Catch specific JsonException in PolicyDataValidator to improve error handling * test: Add unit tests for PolicyDataValidator to validate and serialize policy data and metadata * test: Update PolicyDataValidatorTests to validate organization data ownership metadata --- .../Models/Request/PolicyRequestModel.cs | 29 +-- .../Models/Request/SavePolicyRequest.cs | 45 +--- .../Request/PolicyUpdateRequestModel.cs | 23 +- .../Utilities/PolicyDataValidator.cs | 81 ++++++++ .../Controllers/PoliciesControllerTests.cs | 196 ++++++++++++++++++ .../Controllers/PoliciesControllerTests.cs | 82 ++++++++ .../Models/Request/SavePolicyRequestTests.cs | 31 +-- .../Utilities/PolicyDataValidatorTests.cs | 59 ++++++ 8 files changed, 463 insertions(+), 83 deletions(-) create mode 100644 src/Core/AdminConsole/Utilities/PolicyDataValidator.cs create mode 100644 test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs diff --git a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs index 0e31deacd1..f9b9c18993 100644 --- a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs @@ -1,11 +1,8 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using System.Text.Json; +using System.ComponentModel.DataAnnotations; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; namespace Bit.Api.AdminConsole.Models.Request; @@ -16,14 +13,20 @@ public class PolicyRequestModel public PolicyType? Type { get; set; } [Required] public bool? Enabled { get; set; } - public Dictionary Data { get; set; } + public Dictionary? Data { get; set; } - public async Task ToPolicyUpdateAsync(Guid organizationId, ICurrentContext currentContext) => new() + public async Task ToPolicyUpdateAsync(Guid organizationId, ICurrentContext currentContext) { - Type = Type!.Value, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, Type!.Value); + var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); + + return new() + { + Type = Type!.Value, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = performedBy + }; + } } diff --git a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs index fcdc49882b..5c1acc1c36 100644 --- a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs +++ b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs @@ -1,10 +1,8 @@ using System.ComponentModel.DataAnnotations; -using System.Text.Json; -using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; -using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request; @@ -17,45 +15,10 @@ public class SavePolicyRequest public async Task ToSavePolicyModelAsync(Guid organizationId, ICurrentContext currentContext) { + var policyUpdate = await Policy.ToPolicyUpdateAsync(organizationId, currentContext); + var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, Policy.Type!.Value); var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); - var updatedPolicy = new PolicyUpdate() - { - Type = Policy.Type!.Value, - OrganizationId = organizationId, - Data = Policy.Data != null ? JsonSerializer.Serialize(Policy.Data) : null, - Enabled = Policy.Enabled.GetValueOrDefault(), - }; - - var metadata = MapToPolicyMetadata(); - - return new SavePolicyModel(updatedPolicy, performedBy, metadata); - } - - private IPolicyMetadataModel MapToPolicyMetadata() - { - if (Metadata == null) - { - return new EmptyMetadataModel(); - } - - return Policy?.Type switch - { - PolicyType.OrganizationDataOwnership => MapToPolicyMetadata(), - _ => new EmptyMetadataModel() - }; - } - - private IPolicyMetadataModel MapToPolicyMetadata() where T : IPolicyMetadataModel, new() - { - try - { - var json = JsonSerializer.Serialize(Metadata); - return CoreHelpers.LoadClassFromJsonData(json); - } - catch - { - return new EmptyMetadataModel(); - } + return new SavePolicyModel(policyUpdate, performedBy, metadata); } } diff --git a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs index eb56690462..34675a6046 100644 --- a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs @@ -1,19 +1,24 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Enums; namespace Bit.Api.AdminConsole.Public.Models.Request; public class PolicyUpdateRequestModel : PolicyBaseModel { - public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) => new() + public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) { - Type = type, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new SystemUser(EventSystemUser.PublicApi) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + + return new() + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = new SystemUser(EventSystemUser.PublicApi) + }; + } } diff --git a/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs new file mode 100644 index 0000000000..84e63f2a20 --- /dev/null +++ b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs @@ -0,0 +1,81 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.Utilities; + +public static class PolicyDataValidator +{ + /// + /// Validates and serializes policy data based on the policy type. + /// + /// The policy data to validate + /// The type of policy + /// Serialized JSON string if data is valid, null if data is null or empty + /// Thrown when data validation fails + public static string? ValidateAndSerialize(Dictionary? data, PolicyType policyType) + { + if (data == null || data.Count == 0) + { + return null; + } + + try + { + var json = JsonSerializer.Serialize(data); + + switch (policyType) + { + case PolicyType.MasterPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.SendOptions: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.ResetPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + } + + return json; + } + catch (JsonException ex) + { + var fieldInfo = !string.IsNullOrEmpty(ex.Path) ? $": field '{ex.Path}' has invalid type" : ""; + throw new BadRequestException($"Invalid data for {policyType} policy{fieldInfo}."); + } + } + + /// + /// Validates and deserializes policy metadata based on the policy type. + /// + /// The policy metadata to validate + /// The type of policy + /// Deserialized metadata model, or EmptyMetadataModel if metadata is null, empty, or validation fails + public static IPolicyMetadataModel ValidateAndDeserializeMetadata(Dictionary? metadata, PolicyType policyType) + { + if (metadata == null || metadata.Count == 0) + { + return new EmptyMetadataModel(); + } + + try + { + var json = JsonSerializer.Serialize(metadata); + + return policyType switch + { + PolicyType.OrganizationDataOwnership => + CoreHelpers.LoadClassFromJsonData(json), + _ => new EmptyMetadataModel() + }; + } + catch (JsonException) + { + return new EmptyMetadataModel(); + } + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs index 1efc2f843d..79c31f956d 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs @@ -211,4 +211,200 @@ public class PoliciesControllerTests : IClassFixture, IAs } } + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minLength", content); // Verify field name is in error message + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "minComplexity", "not a number" }, // Wrong type - should be int + { "minLength", 12 } + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minComplexity", content); // Verify field name is in error message + } + + [Fact] + public async Task PutVNext_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.SingleOrg; + var request = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutVNext_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.TwoFactorAuthentication; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Type = policyType, + Enabled = true, + Data = null + }, + Metadata = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index f034426f98..0b5ab660b9 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -160,4 +160,86 @@ public class PoliciesControllerTests : IClassFixture, IAs Assert.Equal(15, data.MinLength); Assert.Equal(true, data.RequireUpper); } + + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.DisableSend; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs index 057680425a..75236fd719 100644 --- a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs +++ b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs @@ -54,7 +54,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_WithNullData_HandlesCorrectly( + public async Task ToSavePolicyModelAsync_WithEmptyData_HandlesCorrectly( Guid organizationId, Guid userId) { @@ -68,10 +68,8 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act @@ -100,10 +98,8 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act @@ -133,8 +129,7 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -152,7 +147,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithNullMetadata_ReturnsEmptyMetadata( + public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithEmptyMetadata_ReturnsEmptyMetadata( Guid organizationId, Guid userId) { @@ -166,10 +161,8 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null - }, - Metadata = null + Enabled = true + } }; // Act @@ -246,8 +239,7 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.MaximumVaultTimeout, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -280,8 +272,7 @@ public class SavePolicyRequestTests Policy = new PolicyRequestModel { Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = errorDictionary }; diff --git a/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs new file mode 100644 index 0000000000..43725d23e0 --- /dev/null +++ b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs @@ -0,0 +1,59 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; +using Bit.Core.Exceptions; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.Utilities; + +public class PolicyDataValidatorTests +{ + [Fact] + public void ValidateAndSerialize_NullData_ReturnsNull() + { + var result = PolicyDataValidator.ValidateAndSerialize(null, PolicyType.MasterPassword); + + Assert.Null(result); + } + + [Fact] + public void ValidateAndSerialize_ValidData_ReturnsSerializedJson() + { + var data = new Dictionary { { "minLength", 12 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minLength\":12", result); + } + + [Fact] + public void ValidateAndSerialize_InvalidDataType_ThrowsBadRequestException() + { + var data = new Dictionary { { "minLength", "not a number" } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + Assert.Contains("minLength", exception.Message); + } + + [Fact] + public void ValidateAndDeserializeMetadata_NullMetadata_ReturnsEmptyMetadataModel() + { + var result = PolicyDataValidator.ValidateAndDeserializeMetadata(null, PolicyType.SingleOrg); + + Assert.IsType(result); + } + + [Fact] + public void ValidateAndDeserializeMetadata_ValidMetadata_ReturnsModel() + { + var metadata = new Dictionary { { "defaultUserCollectionName", "collection name" } }; + + var result = PolicyDataValidator.ValidateAndDeserializeMetadata(metadata, PolicyType.OrganizationDataOwnership); + + Assert.IsType(result); + } +} From b329305b771e3b4a58c3e9ac852e71bf9e103e51 Mon Sep 17 00:00:00 2001 From: Robyn MacCallum Date: Mon, 3 Nov 2025 11:11:42 -0500 Subject: [PATCH 12/14] Update description for AutomaticAppLogIn policy (#6522) --- src/Core/AdminConsole/Enums/PolicyType.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 3ac14d67f3..09fa4ec955 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -45,7 +45,7 @@ public static class PolicyTypeExtensions PolicyType.MaximumVaultTimeout => "Vault timeout", PolicyType.DisablePersonalVaultExport => "Remove individual vault export", PolicyType.ActivateAutofill => "Active auto-fill", - PolicyType.AutomaticAppLogIn => "Automatically log in users for allowed applications", + PolicyType.AutomaticAppLogIn => "Automatic login with SSO", PolicyType.FreeFamiliesSponsorshipPolicy => "Remove Free Bitwarden Families sponsorship", PolicyType.RemoveUnlockWithPin => "Remove unlock with PIN", PolicyType.RestrictedItemTypesPolicy => "Restricted item types", From bda2bd8ac1398280f55b880f118b2a19f23cdb17 Mon Sep 17 00:00:00 2001 From: Dave <3836813+enmande@users.noreply.github.com> Date: Mon, 3 Nov 2025 12:24:00 -0500 Subject: [PATCH 13/14] fix(base-request-validator) [PM-21153] Recovery Code Not Functioning for SSO-required Users (#6481) * chore(feature-flag-keys) [PM-21153]: Add feature flag key for BaseRequestValidator changes. * fix(base-request-validator) [PM-21153]: Add validation state model for composable validation scenarios. * fix(base-request-validator) [PM-21153]: Update BaseRequestValidator to allow validation scenarios to be composable. * fix(base-request-validator) [PM-21153]: Remove validation state object in favor of validator context, per team discussion. * feat(base-request-validator) [PM-21153]: Update tests to use issue feature flag, both execution paths. * fix(base-request-validator) [PM-21153]: Fix a null dictionary check. * chore(base-request-validator) [PM-21153]: Add unit tests around behavior addressed in this feature. * chore(base-request-validator) [PM-21153]: Update comments for clarity. * chore(base-request-validator-tests) [PM-21153]: Update verbiage for tests. * fix(base-request-validator) [PM-21153]: Update validators to no longer need completed scheme management, use 2FA flag for recovery scenarios. * fix(base-request-validator-tests) [PM-21153]: Customize CustomValidatorRequestContext fixture to allow for setting of request-specific flags as part of the request validation (not eagerly truthy). --- src/Core/Constants.cs | 1 + .../CustomValidatorRequestContext.cs | 11 +- .../RequestValidators/BaseRequestValidator.cs | 553 ++++++++++++++---- .../AutoFixture/RequestValidationFixtures.cs | 41 +- .../BaseRequestValidatorTests.cs | 521 ++++++++++++++--- 5 files changed, 933 insertions(+), 194 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index fead9947a0..ccfa4a6e0e 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -156,6 +156,7 @@ public static class FeatureFlagKeys public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; public const string PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword = "pm-23174-manage-account-recovery-permission-drives-the-need-to-set-master-password"; + public const string RecoveryCodeSupportForSsoRequiredUsers = "pm-21153-recovery-code-support-for-sso-required"; public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; /* Autofill Team */ diff --git a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs index a709a47cb2..e16c8ad695 100644 --- a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs @@ -27,6 +27,12 @@ public class CustomValidatorRequestContext /// public bool TwoFactorRequired { get; set; } = false; /// + /// Whether the user has requested recovery of their 2FA methods using their one-time + /// recovery code. + /// + /// + public bool TwoFactorRecoveryRequested { get; set; } = false; + /// /// This communicates whether or not SSO is required for the user to authenticate. /// public bool SsoRequired { get; set; } = false; @@ -42,10 +48,13 @@ public class CustomValidatorRequestContext /// This will be null if the authentication request is successful. /// public Dictionary CustomResponse { get; set; } - /// /// A validated auth request /// /// public AuthRequest ValidatedAuthRequest { get; set; } + /// + /// Whether the user has requested a Remember Me token for their current device. + /// + public bool RememberMeRequested { get; set; } = false; } diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index b976775aca..224c7a1866 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below + #nullable disable using System.Security.Claims; @@ -68,7 +69,7 @@ public abstract class BaseRequestValidator where T : class IAuthRequestRepository authRequestRepository, IMailService mailService, IUserAccountKeysQuery userAccountKeysQuery - ) + ) { _userManager = userManager; _userService = userService; @@ -93,125 +94,141 @@ public abstract class BaseRequestValidator where T : class protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // 1. We need to check if the user's master password hash is correct. - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) + if (FeatureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) { - await UpdateFailedAuthDetailsAsync(user); - - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); - return; - } - - // 2. Decide if this user belongs to an organization that requires SSO. - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) - { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; - } - - // 3. Check if 2FA is required. - (validatorContext.TwoFactorRequired, var twoFactorOrganization) = - await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); - - // This flag is used to determine if the user wants a rememberMe token sent when - // authentication is successful. - var returnRememberMeToken = false; - - if (validatorContext.TwoFactorRequired) - { - var twoFactorToken = request.Raw["TwoFactorToken"]; - var twoFactorProvider = request.Raw["TwoFactorProvider"]; - var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - // 3a. Response for 2FA required and not provided state. - if (!validTwoFactorRequest || - !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) + var validators = DetermineValidationOrder(context, request, validatorContext); + var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); + if (!allValidationSchemesSuccessful) { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - if (resultDict == null) + // Each validation task is responsible for setting its own non-success status, if applicable. + return; + } + await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, + validatorContext.RememberMeRequested); + } + else + { + // 1. We need to check if the user's master password hash is correct. + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (!valid) + { + await UpdateFailedAuthDetailsAsync(user); + + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return; + } + + // 2. Decide if this user belongs to an organization that requires SSO. + validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); + if (validatorContext.SsoRequired) + { + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return; + } + + // 3. Check if 2FA is required. + (validatorContext.TwoFactorRequired, var twoFactorOrganization) = + await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); + + // This flag is used to determine if the user wants a rememberMe token sent when + // authentication is successful. + var returnRememberMeToken = false; + + if (validatorContext.TwoFactorRequired) + { + var twoFactorToken = request.Raw["TwoFactorToken"]; + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + // 3a. Response for 2FA required and not provided state. + if (!validTwoFactorRequest || + !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(user, twoFactorOrganization); + if (resultDict == null) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + return; + } + + // Include Master Password Policy in 2FA response. + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); + SetTwoFactorResult(context, resultDict); return; } - // Include Master Password Policy in 2FA response. - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); + var twoFactorTokenValid = + await _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); + + // 3b. Response for 2FA required but request is not valid or remember token expired state. + if (!twoFactorTokenValid) + { + // The remember me token has expired. + if (twoFactorProviderType == TwoFactorProviderType.Remember) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(user, twoFactorOrganization); + + // Include Master Password Policy in 2FA response + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); + SetTwoFactorResult(context, resultDict); + } + else + { + await SendFailedTwoFactorEmail(user, twoFactorProviderType); + await UpdateFailedAuthDetailsAsync(user); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); + } + + return; + } + + // 3c. When the 2FA authentication is successful, we can check if the user wants a + // rememberMe token. + var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; + // Check if the user wants a rememberMe token. + if (twoFactorRemember + // if the 2FA auth was rememberMe do not send another token. + && twoFactorProviderType != TwoFactorProviderType.Remember) + { + returnRememberMeToken = true; + } + } + + // 4. Check if the user is logging in from a new device. + var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); + if (!deviceValid) + { + SetValidationErrorResult(context, validatorContext); + await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); return; } - var twoFactorTokenValid = - await _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); - - // 3b. Response for 2FA required but request is not valid or remember token expired state. - if (!twoFactorTokenValid) + // 5. Force legacy users to the web for migration. + if (UserService.IsLegacyUser(user) && request.ClientId != "web") { - // The remember me token has expired. - if (twoFactorProviderType == TwoFactorProviderType.Remember) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - - // Include Master Password Policy in 2FA response - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - } - else - { - await SendFailedTwoFactorEmail(user, twoFactorProviderType); - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - } + await FailAuthForLegacyUserAsync(user, context); return; } - // 3c. When the 2FA authentication is successful, we can check if the user wants a - // rememberMe token. - var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; - // Check if the user wants a rememberMe token. - if (twoFactorRemember - // if the 2FA auth was rememberMe do not send another token. - && twoFactorProviderType != TwoFactorProviderType.Remember) + // TODO: PM-24324 - This should be its own validator at some point. + // 6. Auth request handling + if (validatorContext.ValidatedAuthRequest != null) { - returnRememberMeToken = true; + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); } - } - // 4. Check if the user is logging in from a new device. - var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); - if (!deviceValid) - { - SetValidationErrorResult(context, validatorContext); - await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); - return; + await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); } - - // 5. Force legacy users to the web for migration. - if (UserService.IsLegacyUser(user) && request.ClientId != "web") - { - await FailAuthForLegacyUserAsync(user, context); - return; - } - - // TODO: PM-24324 - This should be its own validator at some point. - // 6. Auth request handling - if (validatorContext.ValidatedAuthRequest != null) - { - validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; - await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); - } - - await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); } protected async Task FailAuthForLegacyUserAsync(User user, T context) @@ -223,6 +240,302 @@ public abstract class BaseRequestValidator where T : class protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); + /// + /// Composer for validation schemes. + /// + /// The current request context. + /// + /// + /// A composed array of validation scheme delegates to evaluate in order. + private Func>[] DetermineValidationOrder(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (RecoveryCodeRequestForSsoRequiredUserScenario()) + { + // Support valid requests to recover 2FA (with account code) for users who require SSO + // by organization membership. + // This requires an evaluation of 2FA validity in front of SSO, and an opportunity for the 2FA + // validation to perform the recovery as part of scheme validation based on the request. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + else + { + // The typical validation scenario. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + + bool RecoveryCodeRequestForSsoRequiredUserScenario() + { + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var twoFactorToken = request.Raw["TwoFactorToken"]; + + // Both provider and token must be present; + // Validity of the token for a given provider will be evaluated by the TwoFactorAuthenticationValidator. + if (string.IsNullOrWhiteSpace(twoFactorProvider) || string.IsNullOrWhiteSpace(twoFactorToken)) + { + return false; + } + + if (!int.TryParse(twoFactorProvider, out var providerValue)) + { + return false; + } + + return providerValue == (int)TwoFactorProviderType.RecoveryCode; + } + } + + /// + /// Processes the validation schemes sequentially. + /// Each validator is responsible for setting error context responses on failure and adding itself to the + /// validatorContext's CompletedValidationSchemes (only) on success. + /// Failure of any scheme to validate will short-circuit the collection, causing the validation error to be + /// returned and further schemes to not be evaluated. + /// + /// The collection of validation schemes as composed in + /// true if all schemes validated successfully, false if any failed. + private static async Task ProcessValidatorsAsync(params Func>[] validators) + { + foreach (var validator in validators) + { + if (!await validator()) + { + return false; + } + } + + return true; + } + + /// + /// Validates the user's Master Password hash. + /// + /// The current request context. + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateMasterPasswordAsync(T context, CustomValidatorRequestContext validatorContext) + { + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (valid) + { + return true; + } + + await UpdateFailedAuthDetailsAsync(user); + + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return false; + } + + /// + /// Validates the user's organization-enforced Single Sign-on (SSO) requirement. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + /// + private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); + if (!validatorContext.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (validatorContext.TwoFactorRequired && + validatorContext.TwoFactorRecoveryRequested) + { + SetSsoResult(context, new Dictionary + { + { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } + }); + return false; + } + + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return false; + } + + /// + /// Validates the user's Multi-Factor Authentication (2FA) scheme. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateTwoFactorAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + (validatorContext.TwoFactorRequired, var twoFactorOrganization) = + await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(validatorContext.User, request); + + if (!validatorContext.TwoFactorRequired) + { + return true; + } + + var twoFactorToken = request.Raw["TwoFactorToken"]; + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + // 3a. Response for 2FA required and not provided state. + if (!validTwoFactorRequest || + !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + if (resultDict == null) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, validatorContext.User); + return false; + } + + // Include Master Password Policy in 2FA response. + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + return false; + } + + var twoFactorTokenValid = + await _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(validatorContext.User, twoFactorOrganization, twoFactorProviderType, + twoFactorToken); + + // 3b. Response for 2FA required but request is not valid or remember token expired state. + if (!twoFactorTokenValid) + { + // The remember me token has expired. + if (twoFactorProviderType == TwoFactorProviderType.Remember) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + + // Include Master Password Policy in 2FA response + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + } + else + { + await SendFailedTwoFactorEmail(validatorContext.User, twoFactorProviderType); + await UpdateFailedAuthDetailsAsync(validatorContext.User); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, + validatorContext.User); + } + + return false; + } + + // 3c. Given a valid token and a successful two-factor verification, if the provider type is Recovery Code, + // recovery will have been performed as part of 2FA validation. This will be relevant for, e.g., SSO users + // who are requesting recovery, but who will still need to log in after 2FA recovery. + if (twoFactorProviderType == TwoFactorProviderType.RecoveryCode) + { + validatorContext.TwoFactorRecoveryRequested = true; + } + + // 3d. When the 2FA authentication is successful, we can check if the user wants a + // rememberMe token. + var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; + // Check if the user wants a rememberMe token. + if (twoFactorRemember + // if the 2FA auth was rememberMe do not send another token. + && twoFactorProviderType != TwoFactorProviderType.Remember) + { + validatorContext.RememberMeRequested = true; + } + + return true; + } + + /// + /// Validates whether the user is logging in from a known device. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateNewDeviceAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); + if (deviceValid) + { + return true; + } + + SetValidationErrorResult(context, validatorContext); + await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); + return false; + } + + /// + /// Validates whether the user should be denied access on a given non-Web client and sent to the Web client + /// for Legacy migration. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateLegacyMigrationAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (!UserService.IsLegacyUser(validatorContext.User) || request.ClientId == "web") + { + return true; + } + + await FailAuthForLegacyUserAsync(validatorContext.User, context); + return false; + } + + /// + /// Validates and updates the auth request's timestamp. + /// + /// + /// true on evaluation and/or completed update of the AuthRequest. + private async Task ValidateAuthRequestAsync(CustomValidatorRequestContext validatorContext) + { + // TODO: PM-24324 - This should be its own validator at some point. + if (validatorContext.ValidatedAuthRequest != null) + { + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); + } + + return true; + } /// /// Responsible for building the response to the client when the user has successfully authenticated. @@ -256,7 +569,7 @@ public abstract class BaseRequestValidator where T : class /// used to associate the failed login with a user /// void [Obsolete("Consider using SetValidationErrorResult to set the validation result, and LogFailedLoginEvent " + - "to log the failure.")] + "to log the failure.")] protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) { if (user != null) @@ -268,7 +581,8 @@ public abstract class BaseRequestValidator where T : class if (_globalSettings.SelfHosted) { _logger.LogWarning(Constants.BypassFiltersEventId, - "Failed login attempt. Is2FARequest: {Is2FARequest} IpAddress: {IpAddress}", twoFactorRequest, CurrentContext.IpAddress); + "Failed login attempt. Is2FARequest: {Is2FARequest} IpAddress: {IpAddress}", twoFactorRequest, + CurrentContext.IpAddress); } await Task.Delay(2000); // Delay for brute force. @@ -292,21 +606,26 @@ public abstract class BaseRequestValidator where T : class formattedMessage = string.Format("Failed login attempt. {0}", $" {CurrentContext.IpAddress}"); break; case EventType.User_FailedLogIn2fa: - formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", $" {CurrentContext.IpAddress}"); + formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", + $" {CurrentContext.IpAddress}"); break; default: formattedMessage = "Failed login attempt."; break; } + _logger.LogWarning(Constants.BypassFiltersEventId, "{FailedLoginMessage}", formattedMessage); } + await Task.Delay(2000); // Delay for brute force. } [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetSsoResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -317,6 +636,7 @@ public abstract class BaseRequestValidator where T : class /// The current grant or token context /// The modified request context containing material used to build the response object protected abstract void SetValidationErrorResult(T context, CustomValidatorRequestContext requestContext); + protected abstract Task SetSuccessResult(T context, User user, List claims, Dictionary customResponse); @@ -343,7 +663,7 @@ public abstract class BaseRequestValidator where T : class // Check if user belongs to any organization with an active SSO policy var ssoRequired = FeatureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired + .SsoRequired : await PolicyService.AnyPoliciesApplicableToUserAsync( user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); if (ssoRequired) @@ -385,7 +705,8 @@ public abstract class BaseRequestValidator where T : class { if (FeatureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) { - await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, CurrentContext.IpAddress); + await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, + CurrentContext.IpAddress); } } @@ -416,16 +737,14 @@ public abstract class BaseRequestValidator where T : class // We need this because we check for changes in the stamp to determine if we need to invalidate token refresh requests, // in the `ProfileService.IsActiveAsync` method. // If we don't store the security stamp in the persisted grant, we won't have the previous value to compare against. - var claims = new List - { - new Claim(Claims.SecurityStamp, user.SecurityStamp) - }; + var claims = new List { new Claim(Claims.SecurityStamp, user.SecurityStamp) }; if (device != null) { claims.Add(new Claim(Claims.Device, device.Identifier)); claims.Add(new Claim(Claims.DeviceType, device.Type.ToString())); } + return claims; } @@ -437,7 +756,8 @@ public abstract class BaseRequestValidator where T : class /// The current request context. /// The device used for authentication. /// Whether to send a 2FA remember token. - private async Task> BuildCustomResponse(User user, T context, Device device, bool sendRememberToken) + private async Task> BuildCustomResponse(User user, T context, Device device, + bool sendRememberToken) { var customResponse = new Dictionary(); if (!string.IsNullOrWhiteSpace(user.PrivateKey)) @@ -459,7 +779,8 @@ public abstract class BaseRequestValidator where T : class customResponse.Add("KdfIterations", user.KdfIterations); customResponse.Add("KdfMemory", user.KdfMemory); customResponse.Add("KdfParallelism", user.KdfParallelism); - customResponse.Add("UserDecryptionOptions", await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); + customResponse.Add("UserDecryptionOptions", + await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); if (sendRememberToken) { @@ -467,6 +788,7 @@ public abstract class BaseRequestValidator where T : class CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); customResponse.Add("TwoFactorToken", token); } + return customResponse; } @@ -474,7 +796,8 @@ public abstract class BaseRequestValidator where T : class /// /// Used to create a list of all possible ways the newly authenticated user can decrypt their vault contents /// - private async Task CreateUserDecryptionOptionsAsync(User user, Device device, ClaimsPrincipal subject) + private async Task CreateUserDecryptionOptionsAsync(User user, Device device, + ClaimsPrincipal subject) { var ssoConfig = await GetSsoConfigurationDataAsync(subject); return await UserDecryptionOptionsBuilder diff --git a/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs b/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs index 5ee3bda956..3063524a57 100644 --- a/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs +++ b/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs @@ -1,6 +1,7 @@ using System.Reflection; using AutoFixture; using AutoFixture.Xunit2; +using Bit.Identity.IdentityServer; using Duende.IdentityServer.Validation; namespace Bit.Identity.Test.AutoFixture; @@ -8,7 +9,8 @@ namespace Bit.Identity.Test.AutoFixture; internal class ValidatedTokenRequestCustomization : ICustomization { public ValidatedTokenRequestCustomization() - { } + { + } public void Customize(IFixture fixture) { @@ -22,10 +24,45 @@ internal class ValidatedTokenRequestCustomization : ICustomization public class ValidatedTokenRequestAttribute : CustomizeAttribute { public ValidatedTokenRequestAttribute() - { } + { + } public override ICustomization GetCustomization(ParameterInfo parameter) { return new ValidatedTokenRequestCustomization(); } } + +internal class CustomValidatorRequestContextCustomization : ICustomization +{ + public CustomValidatorRequestContextCustomization() + { + } + + /// + /// Specific context members like , + /// , and + /// should initialize false, + /// and are made truthy in context upon evaluation of a request. Do not allow AutoFixture to eagerly make these + /// truthy; that is the responsibility of the + /// + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.RememberMeRequested, false) + .With(o => o.TwoFactorRecoveryRequested, false) + .With(o => o.SsoRequired, false)); + } +} + +public class CustomValidatorRequestContextAttribute : CustomizeAttribute +{ + public CustomValidatorRequestContextAttribute() + { + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new CustomValidatorRequestContextCustomization(); + } +} diff --git a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs index 53615cd1d1..e78c7d161c 100644 --- a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs @@ -100,19 +100,30 @@ public class BaseRequestValidatorTests _userAccountKeysQuery); } + private void SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(bool recoveryCodeSupportEnabled) + { + _featureService + .IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers) + .Returns(recoveryCodeSupportEnabled); + } + /* Logic path * ValidateAsync -> UpdateFailedAuthDetailsAsync -> _mailService.SendFailedLoginAttemptsEmailAsync * |-> BuildErrorResultAsync -> _eventService.LogUserEventAsync * (self hosted) |-> _logger.LogWarning() * |-> SetErrorResult */ - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_ContextNotValid_SelfHosted_ShouldBuildErrorResult_ShouldLogWarning( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _globalSettings.SelfHosted = true; _sut.isValid = false; @@ -122,18 +133,23 @@ public class BaseRequestValidatorTests // Assert var logs = _logger.Collector.GetSnapshot(true); - Assert.Contains(logs, l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: "); + Assert.Contains(logs, + l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: "); var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_DeviceNotValidated_ShouldLogError( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass _sut.isValid = true; @@ -141,14 +157,15 @@ public class BaseRequestValidatorTests // 2 -> will result to false with no extra configuration // 3 -> set two factor to be false _twoFactorAuthenticationValidator - .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) - .Returns(Task.FromResult(new Tuple(false, null))); + .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) + .Returns(Task.FromResult(new Tuple(false, null))); // 4 -> set up device validator to fail requestContext.KnownDevice = false; tokenRequest.GrantType = "password"; - _deviceValidator.ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) - .Returns(Task.FromResult(false)); + _deviceValidator + .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(false)); // 5 -> not legacy user _userService.IsLegacyUser(Arg.Any()) @@ -163,13 +180,17 @@ public class BaseRequestValidatorTests .LogUserEventAsync(context.CustomValidatorRequestContext.User.Id, EventType.User_FailedLogIn); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_DeviceValidated_ShouldSucceed( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass _sut.isValid = true; @@ -177,12 +198,13 @@ public class BaseRequestValidatorTests // 2 -> will result to false with no extra configuration // 3 -> set two factor to be false _twoFactorAuthenticationValidator - .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) - .Returns(Task.FromResult(new Tuple(false, null))); + .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) + .Returns(Task.FromResult(new Tuple(false, null))); // 4 -> set up device validator to pass - _deviceValidator.ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) - .Returns(Task.FromResult(true)); + _deviceValidator + .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); // 5 -> not legacy user _userService.IsLegacyUser(Arg.Any()) @@ -202,13 +224,17 @@ public class BaseRequestValidatorTests Assert.False(context.GrantResult.IsError); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_ValidatedAuthRequest_ConsumedOnSuccess( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass _sut.isValid = true; @@ -235,7 +261,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); // 4 -> set up device validator to pass - _deviceValidator.ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + _deviceValidator + .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) .Returns(Task.FromResult(true)); // 5 -> not legacy user @@ -260,13 +287,17 @@ public class BaseRequestValidatorTests ar.AuthenticationDate.HasValue)); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_ValidatedAuthRequest_NotConsumed_When2faRequired( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass _sut.isValid = true; @@ -302,13 +333,17 @@ public class BaseRequestValidatorTests await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any()); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_TwoFactorTokenInvalid_ShouldSendFailedTwoFactorEmail( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -345,13 +380,17 @@ public class BaseRequestValidatorTests Arg.Any()); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_TwoFactorRememberTokenExpired_ShouldNotSendFailedTwoFactorEmail( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -391,28 +430,34 @@ public class BaseRequestValidatorTests // Assert // Verify that the failed 2FA email was NOT sent for remember token expiration await _mailService.DidNotReceive() - .SendFailedTwoFactorAttemptEmailAsync(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()); + .SendFailedTwoFactorAttemptEmailAsync(Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()); } // Test grantTypes that require SSO when a user is in an organization that requires it [Theory] - [BitAutoData("password")] - [BitAutoData("webauthn")] - [BitAutoData("refresh_token")] + [BitAutoData("password", true)] + [BitAutoData("password", false)] + [BitAutoData("webauthn", true)] + [BitAutoData("webauthn", false)] + [BitAutoData("refresh_token", true)] + [BitAutoData("refresh_token", false)] public async Task ValidateAsync_GrantTypes_OrgSsoRequiredTrue_ShouldSetSsoResult( string grantType, + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; context.ValidatedTokenRequest.GrantType = grantType; _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -425,16 +470,21 @@ public class BaseRequestValidatorTests // Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled [Theory] - [BitAutoData("password")] - [BitAutoData("webauthn")] - [BitAutoData("refresh_token")] + [BitAutoData("password", true)] + [BitAutoData("password", false)] + [BitAutoData("webauthn", true)] + [BitAutoData("webauthn", false)] + [BitAutoData("refresh_token", true)] + [BitAutoData("refresh_token", false)] public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredTrue_ShouldSetSsoResult( string grantType, + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -449,23 +499,28 @@ public class BaseRequestValidatorTests // Assert await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); Assert.True(context.GrantResult.IsError); var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; Assert.Equal("SSO authentication is required.", errorResponse.Message); } [Theory] - [BitAutoData("password")] - [BitAutoData("webauthn")] - [BitAutoData("refresh_token")] + [BitAutoData("password", true)] + [BitAutoData("password", false)] + [BitAutoData("webauthn", true)] + [BitAutoData("webauthn", false)] + [BitAutoData("refresh_token", true)] + [BitAutoData("refresh_token", false)] public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredFalse_ShouldSucceed( string grantType, + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -500,24 +555,29 @@ public class BaseRequestValidatorTests // Test grantTypes where SSO would be required but the user is not in an // organization that requires it [Theory] - [BitAutoData("password")] - [BitAutoData("webauthn")] - [BitAutoData("refresh_token")] + [BitAutoData("password", true)] + [BitAutoData("password", false)] + [BitAutoData("webauthn", true)] + [BitAutoData("webauthn", false)] + [BitAutoData("refresh_token", true)] + [BitAutoData("refresh_token", false)] public async Task ValidateAsync_GrantTypes_OrgSsoRequiredFalse_ShouldSucceed( string grantType, + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; context.ValidatedTokenRequest.GrantType = grantType; _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(false)); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(false)); _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -540,20 +600,23 @@ public class BaseRequestValidatorTests await _userRepository.Received(1).ReplaceAsync(Arg.Any()); Assert.False(context.GrantResult.IsError); - } // Test the grantTypes where SSO is in progress or not relevant [Theory] - [BitAutoData("authorization_code")] - [BitAutoData("client_credentials")] + [BitAutoData("authorization_code", true)] + [BitAutoData("authorization_code", false)] + [BitAutoData("client_credentials", true)] + [BitAutoData("client_credentials", false)] public async Task ValidateAsync_GrantTypes_SsoRequiredFalse_ShouldSucceed( string grantType, + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -577,7 +640,7 @@ public class BaseRequestValidatorTests // Assert await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); await _eventService.Received(1).LogUserEventAsync( context.CustomValidatorRequestContext.User.Id, EventType.User_LoggedIn); await _userRepository.Received(1).ReplaceAsync(Arg.Any()); @@ -588,13 +651,17 @@ public class BaseRequestValidatorTests /* Logic Path * ValidateAsync -> UserService.IsLegacyUser -> FailAuthForLegacyUserAsync */ - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_IsLegacyUser_FailAuthForLegacyUserAsync( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = context.CustomValidatorRequestContext.User; user.Key = null; @@ -613,21 +680,27 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError); var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - var expectedMessage = "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support"; + var expectedMessage = + "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support"; Assert.Equal(expectedMessage, errorResponse.Message); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_CustomResponse_NoMasterPassword_ShouldSetUserDecryptionOptions( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); - _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()).Returns(_userDecryptionOptionsBuilder); + _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()) + .Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions { HasMasterPassword = false, @@ -663,19 +736,24 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(KdfType.PBKDF2_SHA256, 654_321, null, null)] - [BitAutoData(KdfType.Argon2id, 11, 128, 5)] + [BitAutoData(true, KdfType.PBKDF2_SHA256, 654_321, null, null)] + [BitAutoData(false, KdfType.PBKDF2_SHA256, 654_321, null, null)] + [BitAutoData(true, KdfType.Argon2id, 11, 128, 5)] + [BitAutoData(false, KdfType.Argon2id, 11, 128, 5)] public async Task ValidateAsync_CustomResponse_MasterPassword_ShouldSetUserDecryptionOptions( + bool featureFlagValue, KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); - _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()).Returns(_userDecryptionOptionsBuilder); + _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()) + .Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions { HasMasterPassword = true, @@ -728,13 +806,17 @@ public class BaseRequestValidatorTests Assert.Equal("test@example.com", userDecryptionOptions.MasterPasswordUnlock.Salt); } - [Theory, BitAutoData] + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_CustomResponse_ShouldIncludeAccountKeys( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var mockAccountKeys = new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -747,11 +829,7 @@ public class BaseRequestValidatorTests "test-wrapped-signing-key", "test-verifying-key" ), - SecurityStateData = new SecurityStateData - { - SecurityState = "test-security-state", - SecurityVersion = 2 - } + SecurityStateData = new SecurityStateData { SecurityState = "test-security-state", SecurityVersion = 2 } }; _userAccountKeysQuery.Run(Arg.Any()).Returns(mockAccountKeys); @@ -759,7 +837,8 @@ public class BaseRequestValidatorTests _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); - _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()).Returns(_userDecryptionOptionsBuilder); + _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()) + .Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions { HasMasterPassword = true, @@ -808,13 +887,18 @@ public class BaseRequestValidatorTests Assert.Equal("test-security-state", accountKeysResponse.SecurityState.SecurityState); Assert.Equal(2, accountKeysResponse.SecurityState.SecurityVersion); } - [Theory, BitAutoData] + + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_CustomResponse_AccountKeysQuery_SkippedWhenPrivateKeyIsNull( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) + bool featureFlagValue, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); requestContext.User.PrivateKey = null; var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -833,13 +917,18 @@ public class BaseRequestValidatorTests // Verify that the account keys query wasn't called. await _userAccountKeysQuery.Received(0).Run(Arg.Any()); } - [Theory, BitAutoData] + + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] public async Task ValidateAsync_CustomResponse_AccountKeysQuery_CalledWithCorrectUser( + bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var expectedUser = requestContext.User; _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData @@ -853,7 +942,8 @@ public class BaseRequestValidatorTests _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); - _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()).Returns(_userDecryptionOptionsBuilder); + _userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any()) + .Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions())); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -874,6 +964,285 @@ public class BaseRequestValidatorTests await _userAccountKeysQuery.Received(1).Run(Arg.Is(u => u.Id == expectedUser.Id)); } + /// + /// Tests the core PM-21153 feature: SSO-required users can use recovery codes to disable 2FA, + /// but must then authenticate via SSO with a descriptive message about the recovery. + /// This test validates: + /// 1. Validation order is changed (2FA before SSO) when recovery code is provided + /// 2. Recovery code successfully validates and sets TwoFactorRecoveryRequested flag + /// 3. SSO validation then fails with recovery-specific message + /// 4. User is NOT logged in (must authenticate via IdP) + /// + [Theory] + [BitAutoData(true)] // Feature flag ON - new behavior + [BitAutoData(false)] // Feature flag OFF - should fail at SSO before 2FA recovery + public async Task ValidateAsync_RecoveryCodeForSsoRequiredUser_BlocksWithDescriptiveMessage( + bool featureFlagEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); + var context = CreateContext(tokenRequest, requestContext, grantResult); + var user = requestContext.User; + + // Reset state that AutoFixture may have populated + requestContext.TwoFactorRecoveryRequested = false; + requestContext.RememberMeRequested = false; + + // 1. Master password is valid + _sut.isValid = true; + + // 2. SSO is required (this user is in an org that requires SSO) + _policyService.AnyPoliciesApplicableToUserAsync( + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); + + // 3. 2FA is required + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(user, tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + // 4. Provide a RECOVERY CODE (this triggers the special validation order) + tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); + tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-12345"; + + // 5. Recovery code is valid (UserService.RecoverTwoFactorAsync will be called internally) + _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-12345") + .Returns(Task.FromResult(true)); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery"); + + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + + if (featureFlagEnabled) + { + // NEW BEHAVIOR: Recovery succeeds, then SSO blocks with descriptive message + Assert.Equal( + "Two-factor recovery has been performed. SSO authentication is required.", + errorResponse.Message); + + // Verify recovery was marked + Assert.True(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested flag should be set"); + } + else + { + // LEGACY BEHAVIOR: SSO blocks BEFORE recovery can happen + Assert.Equal( + "SSO authentication is required.", + errorResponse.Message); + + // Recovery never happened because SSO checked first + Assert.False(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested should be false (SSO blocked first)"); + } + + // In both cases: User is NOT logged in + await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn); + } + + /// + /// Tests that validation order changes when a recovery code is PROVIDED (even if invalid). + /// This ensures the RecoveryCodeRequestForSsoRequiredUserScenario() logic is based on + /// request structure, not validation outcome. An SSO-required user who provides an + /// INVALID recovery code should: + /// 1. Have 2FA validated BEFORE SSO (new order) + /// 2. Get a 2FA error (invalid token) + /// 3. NOT get the recovery-specific SSO message (because recovery didn't complete) + /// 4. NOT be logged in + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_InvalidRecoveryCodeForSsoRequiredUser_FailsAt2FA( + bool featureFlagEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); + var context = CreateContext(tokenRequest, requestContext, grantResult); + var user = requestContext.User; + + // 1. Master password is valid + _sut.isValid = true; + + // 2. SSO is required + _policyService.AnyPoliciesApplicableToUserAsync( + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); + + // 3. 2FA is required + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(user, tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + // 4. Provide a RECOVERY CODE (triggers validation order change) + tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); + tokenRequest.Raw["TwoFactorToken"] = "INVALID-recovery-code"; + + // 5. Recovery code is INVALID + _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "INVALID-recovery-code") + .Returns(Task.FromResult(false)); + + // 6. Setup for failed 2FA email (if feature flag enabled) + _featureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail).Returns(true); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code"); + + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + + if (featureFlagEnabled) + { + // NEW BEHAVIOR: 2FA is checked first (due to recovery code request), fails with 2FA error + Assert.Equal( + "Two-step token is invalid. Try again.", + errorResponse.Message); + + // Recovery was attempted but failed - flag should NOT be set + Assert.False(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested should be false (recovery failed)"); + + // Verify failed 2FA email was sent + await _mailService.Received(1).SendFailedTwoFactorAttemptEmailAsync( + user.Email, + TwoFactorProviderType.RecoveryCode, + Arg.Any(), + Arg.Any()); + + // Verify failed login event was logged + await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_FailedLogIn2fa); + } + else + { + // LEGACY BEHAVIOR: SSO is checked first, blocks before 2FA + Assert.Equal( + "SSO authentication is required.", + errorResponse.Message); + + // 2FA validation never happened + await _mailService.DidNotReceive().SendFailedTwoFactorAttemptEmailAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + } + + // In both cases: User is NOT logged in + await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn); + + // Verify user failed login count was updated (in new behavior path) + if (featureFlagEnabled) + { + await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => + u.Id == user.Id && u.FailedLoginCount > 0)); + } + } + + /// + /// Tests that non-SSO users can successfully use recovery codes to disable 2FA and log in. + /// This validates: + /// 1. Validation order changes to 2FA-first when recovery code is provided + /// 2. Recovery code validates successfully + /// 3. SSO check passes (user not in SSO-required org) + /// 4. User successfully logs in + /// 5. TwoFactorRecoveryRequested flag is set (for logging/audit purposes) + /// This is the "happy path" for recovery code usage. + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_RecoveryCodeForNonSsoUser_SuccessfulLogin( + bool featureFlagEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); + var context = CreateContext(tokenRequest, requestContext, grantResult); + var user = requestContext.User; + + // 1. Master password is valid + _sut.isValid = true; + + // 2. SSO is NOT required (this is a regular user, not in SSO org) + _policyService.AnyPoliciesApplicableToUserAsync( + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(false)); + + // 3. 2FA is required + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(user, tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + // 4. Provide a RECOVERY CODE + tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); + tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-67890"; + + // 5. Recovery code is valid + _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-67890") + .Returns(Task.FromResult(true)); + + // 6. Device validation passes + _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 7. User is not legacy + _userService.IsLegacyUser(Arg.Any()) + .Returns(false); + + // 8. Setup user account keys for successful login response + _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( + "test-private-key", + "test-public-key" + ) + }); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.False(context.GrantResult.IsError, "Authentication should succeed for non-SSO user with valid recovery code"); + + // Verify user successfully logged in + await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_LoggedIn); + + // Verify failed login count was reset (successful login) + await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => + u.Id == user.Id && u.FailedLoginCount == 0)); + + if (featureFlagEnabled) + { + // NEW BEHAVIOR: Recovery flag should be set for audit purposes + Assert.True(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested flag should be set for audit/logging"); + } + else + { + // LEGACY BEHAVIOR: Recovery flag doesn't exist, but login still succeeds + // (SSO check happens before 2FA in legacy, but user is not SSO-required so both pass) + Assert.False(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested should be false in legacy mode"); + } + } + private BaseRequestValidationContextFake CreateContext( ValidatedTokenRequest tokenRequest, CustomValidatorRequestContext requestContext, From b4d6f3cb3536e5bd33d76cc711bb777a3715cb0c Mon Sep 17 00:00:00 2001 From: Vincent Salucci <26154748+vincentsalucci@users.noreply.github.com> Date: Mon, 3 Nov 2025 13:32:09 -0600 Subject: [PATCH 14/14] chore: fix provider account recovery flag key, refs PM-24192 (#6533) --- src/Core/Constants.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index ccfa4a6e0e..78f1db5228 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -142,7 +142,7 @@ public static class FeatureFlagKeys public const string CreateDefaultLocation = "pm-19467-create-default-location"; public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; - public const string AccountRecoveryCommand = "pm-24192-account-recovery-command"; + public const string AccountRecoveryCommand = "pm-25581-prevent-provider-account-recovery"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence";