diff --git a/.github/ISSUE_TEMPLATE/bw-lite.yml b/.github/ISSUE_TEMPLATE/bw-lite.yml index f46f4b3e37..cc36164e8f 100644 --- a/.github/ISSUE_TEMPLATE/bw-lite.yml +++ b/.github/ISSUE_TEMPLATE/bw-lite.yml @@ -1,4 +1,4 @@ -name: Bitwarden Lite Deployment Bug Report +name: Bitwarden lite Deployment Bug Report description: File a bug report labels: [bug, bw-lite-deploy] body: @@ -74,7 +74,7 @@ body: id: epic-label attributes: label: Issue-Link - description: Link to our pinned issue, tracking all Bitwarden Lite + description: Link to our pinned issue, tracking all Bitwarden lite value: | https://github.com/bitwarden/server/issues/2480 validations: diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 6a23a7e832..34b59db925 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -63,7 +63,6 @@ }, { matchPackageNames: [ - "Azure.Extensions.AspNetCore.DataProtection.Blobs", "DuoUniversal", "Fido2.AspNet", "Duende.IdentityServer", @@ -137,6 +136,7 @@ "AspNetCoreRateLimit", "AspNetCoreRateLimit.Redis", "Azure.Data.Tables", + "Azure.Extensions.AspNetCore.DataProtection.Blobs", "Azure.Messaging.EventGrid", "Azure.Messaging.ServiceBus", "Azure.Storage.Blobs", diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 877281ccb0..ace6dfdc5d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -185,13 +185,6 @@ jobs: - name: Log in to ACR - production subscription run: az acr login -n bitwardenprod - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat - uses: bitwarden/gh-actions/get-keyvault-secrets@main - with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" - ########## Generate image tag and build Docker image ########## - name: Generate Docker image tag id: tag @@ -250,8 +243,6 @@ jobs: linux/arm64 push: true tags: ${{ steps.image-tags.outputs.tags }} - secrets: | - "GH_PAT=${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}" - name: Install Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' @@ -280,7 +271,7 @@ jobs: output-format: sarif - name: Upload Grype results to GitHub - uses: github/codeql-action/upload-sarif@0499de31b99561a6d14a36a5f662c2a54f91beee # v4.31.2 + uses: github/codeql-action/upload-sarif@e12f0178983d466f2f6028f5cc7a6d786fd97f4b # v4.31.4 with: sarif_file: ${{ steps.container-scan.outputs.sarif }} sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }} @@ -479,20 +470,27 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main - - name: Trigger Bitwarden Lite build + - name: Generate GH App token + uses: actions/create-github-app-token@67018539274d69449ef7c02e8e71183d1719ab42 # v2.1.4 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + + - name: Trigger Bitwarden lite build uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', @@ -520,20 +518,27 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main + - name: Generate GH App token + uses: actions/create-github-app-token@67018539274d69449ef7c02e8e71183d1719ab42 # v2.1.4 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + - name: Trigger k8s deploy uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index 20bc67bc6b..449855ee35 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -62,7 +62,7 @@ jobs: docker compose --profile mssql --profile postgres --profile mysql up -d shell: pwsh - - name: Add MariaDB for Bitwarden Lite + - name: Add MariaDB for Bitwarden lite # Use a different port than MySQL run: | docker run --detach --name mariadb --env MARIADB_ROOT_PASSWORD=mariadb-password -p 4306:3306 mariadb:10 @@ -133,7 +133,7 @@ jobs: # Default Sqlite BW_TEST_DATABASES__3__TYPE: "Sqlite" BW_TEST_DATABASES__3__CONNECTIONSTRING: "Data Source=${{ runner.temp }}/test.db" - # Bitwarden Lite MariaDB + # Bitwarden lite MariaDB BW_TEST_DATABASES__4__TYPE: "MySql" BW_TEST_DATABASES__4__CONNECTIONSTRING: "server=localhost;port=4306;uid=root;pwd=mariadb-password;database=vault_dev;Allow User Variables=true" run: dotnet test --logger "trx;LogFileName=infrastructure-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" diff --git a/Directory.Build.props b/Directory.Build.props index 3e55b8a8cc..d0998430c4 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.11.1 + 2025.12.0 Bit.$(MSBuildProjectName) enable diff --git a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs index 434c265f26..219e6846bd 100644 --- a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs +++ b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs @@ -19,7 +19,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable public virtual async Task StartAsync(CancellationToken cancellationToken) { // Wait 20 seconds to allow database to come online - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); var maxMigrationAttempts = 10; for (var i = 1; i <= maxMigrationAttempts; i++) @@ -41,7 +41,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable { _logger.LogError(e, "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); } } } diff --git a/src/Api/Api.csproj b/src/Api/Api.csproj index 138549e92d..48fedfc8c1 100644 --- a/src/Api/Api.csproj +++ b/src/Api/Api.csproj @@ -33,7 +33,7 @@ - + diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs index e2bca930d1..57140317e3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs @@ -4,6 +4,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; @@ -16,19 +18,22 @@ public class SavePolicyCommand : ISavePolicyCommand private readonly IReadOnlyDictionary _policyValidators; private readonly TimeProvider _timeProvider; private readonly IPostSavePolicySideEffect _postSavePolicySideEffect; + private readonly IPushNotificationService _pushNotificationService; public SavePolicyCommand(IApplicationCacheService applicationCacheService, IEventService eventService, IPolicyRepository policyRepository, IEnumerable policyValidators, TimeProvider timeProvider, - IPostSavePolicySideEffect postSavePolicySideEffect) + IPostSavePolicySideEffect postSavePolicySideEffect, + IPushNotificationService pushNotificationService) { _applicationCacheService = applicationCacheService; _eventService = eventService; _policyRepository = policyRepository; _timeProvider = timeProvider; _postSavePolicySideEffect = postSavePolicySideEffect; + _pushNotificationService = pushNotificationService; var policyValidatorsDict = new Dictionary(); foreach (var policyValidator in policyValidators) @@ -75,6 +80,8 @@ public class SavePolicyCommand : ISavePolicyCommand await _policyRepository.UpsertAsync(policy); await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + await PushPolicyUpdateToClients(policy.OrganizationId, policy); + return policy; } @@ -152,4 +159,17 @@ public class SavePolicyCommand : ISavePolicyCommand var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdate.Type); return (savedPoliciesDict, currentPolicy); } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => this._pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs index 5d40cb211f..38e417d085 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs @@ -5,6 +5,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Int using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; @@ -15,7 +17,8 @@ public class VNextSavePolicyCommand( IPolicyRepository policyRepository, IEnumerable policyUpdateEventHandlers, TimeProvider timeProvider, - IPolicyEventHandlerFactory policyEventHandlerFactory) + IPolicyEventHandlerFactory policyEventHandlerFactory, + IPushNotificationService pushNotificationService) : IVNextSavePolicyCommand { @@ -74,7 +77,7 @@ public class VNextSavePolicyCommand( policy.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; await policyRepository.UpsertAsync(policy); - + await PushPolicyUpdateToClients(policyUpdateRequest.OrganizationId, policy); return policy; } @@ -192,4 +195,17 @@ public class VNextSavePolicyCommand( var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type); return savedPoliciesDict; } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 84ec4acd69..eed6ded12c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -38,6 +38,7 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); } [Obsolete("Use AddPolicyUpdateEvents instead.")] diff --git a/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c932eb0c34 --- /dev/null +++ b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,23 @@ +using Bit.Core.Entities; + +namespace Bit.Core.Auth.Sso; + +/// +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public interface IUserSsoOrganizationIdentifierQuery +{ + /// + /// Retrieves the SSO organization identifier for a confirmed organization user. + /// If there is more than one organization a User is associated with, we return null. If there are more than one + /// organization there is no way to know which organization the user wishes to authenticate with. + /// Owners and Admins who are not subject to the SSO required policy cannot utilize this flow, since they may have + /// multiple organizations with different SSO configurations. + /// + /// The ID of the to retrieve the SSO organization for. _Not_ an . + /// + /// The organization identifier if the user is a confirmed member of an organization with SSO configured, + /// otherwise null + /// + Task GetSsoOrganizationIdentifierAsync(Guid userId); +} diff --git a/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c0751e1f1a --- /dev/null +++ b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,38 @@ +using Bit.Core.Enums; +using Bit.Core.Repositories; + +namespace Bit.Core.Auth.Sso; + +/// +/// TODO : PM-28846 review data structures as they relate to this query +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public class UserSsoOrganizationIdentifierQuery( + IOrganizationUserRepository _organizationUserRepository, + IOrganizationRepository _organizationRepository) : IUserSsoOrganizationIdentifierQuery +{ + /// + public async Task GetSsoOrganizationIdentifierAsync(Guid userId) + { + // Get all confirmed organization memberships for the user + var organizationUsers = await _organizationUserRepository.GetManyByUserAsync(userId); + + // we can only confidently return the correct SsoOrganizationIdentifier if there is exactly one Organization. + // The user must also be in the Confirmed status. + var confirmedOrgUsers = organizationUsers.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed); + if (confirmedOrgUsers.Count() != 1) + { + return null; + } + + var confirmedOrgUser = confirmedOrgUsers.Single(); + var organization = await _organizationRepository.GetByIdAsync(confirmedOrgUser.OrganizationId); + + if (organization == null) + { + return null; + } + + return organization.Identifier; + } +} diff --git a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs index 53bd8bdba2..7c50f7f17b 100644 --- a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs +++ b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ - - +using Bit.Core.Auth.Sso; using Bit.Core.Auth.UserFeatures.DeviceTrust; using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Auth.UserFeatures.Registration.Implementations; @@ -29,6 +28,7 @@ public static class UserServiceCollectionExtensions services.AddWebAuthnLoginCommands(); services.AddTdeOffboardingPasswordCommands(); services.AddTwoFactorQueries(); + services.AddSsoQueries(); } public static void AddDeviceTrustCommands(this IServiceCollection services) @@ -69,4 +69,9 @@ public static class UserServiceCollectionExtensions { services.AddScoped(); } + + private static void AddSsoQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 5d2cd54489..af5b738cd0 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -166,6 +166,7 @@ public static class FeatureFlagKeys public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; + public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; /* Autofill Team */ public const string IdpAutoSubmitLogin = "idp-auto-submit-login"; @@ -202,14 +203,11 @@ public static class FeatureFlagKeys /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; - public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service"; public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration"; public const string Argon2Default = "argon2-default"; public const string UserkeyRotationV2 = "userkey-rotation-v2"; public const string SSHKeyItemVaultItem = "ssh-key-vault-item"; - public const string UserSdkForDecryption = "use-sdk-for-decryption"; public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation"; - public const string PM17987_BlockType0 = "pm-17987-block-type-0"; public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings"; public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data"; public const string WindowsBiometricsV2 = "pm-25373-windows-biometrics-v2"; diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index 5d9b5a1759..6067c60556 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -38,10 +38,6 @@ public class CurrentContext( public virtual List Providers { get; set; } public virtual Guid? InstallationId { get; set; } public virtual Guid? OrganizationId { get; set; } - public virtual bool CloudflareWorkerProxied { get; set; } - public virtual bool IsBot { get; set; } - public virtual bool MaybeBot { get; set; } - public virtual int? BotScore { get; set; } public virtual string ClientId { get; set; } public virtual Version ClientVersion { get; set; } public virtual bool ClientVersionIsPrerelease { get; set; } @@ -70,27 +66,6 @@ public class CurrentContext( DeviceType = dType; } - if (!BotScore.HasValue && httpContext.Request.Headers.TryGetValue("X-Cf-Bot-Score", out var cfBotScore) && - int.TryParse(cfBotScore, out var parsedBotScore)) - { - BotScore = parsedBotScore; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Worked-Proxied", out var cfWorkedProxied)) - { - CloudflareWorkerProxied = cfWorkedProxied == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Is-Bot", out var cfIsBot)) - { - IsBot = cfIsBot == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Maybe-Bot", out var cfMaybeBot)) - { - MaybeBot = cfMaybeBot == "1"; - } - if (httpContext.Request.Headers.TryGetValue("Bitwarden-Client-Version", out var bitWardenClientVersion) && Version.TryParse(bitWardenClientVersion, out var cVersion)) { ClientVersion = cVersion; diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index f62a048070..d527cdd363 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -31,9 +31,6 @@ public interface ICurrentContext Guid? InstallationId { get; set; } Guid? OrganizationId { get; set; } IdentityClientType IdentityClientType { get; set; } - bool IsBot { get; set; } - bool MaybeBot { get; set; } - int? BotScore { get; set; } string ClientId { get; set; } Version ClientVersion { get; set; } bool ClientVersionIsPrerelease { get; set; } diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 1be6e52854..e26cc26b4a 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -25,12 +25,12 @@ - + - - - + + + @@ -60,9 +60,9 @@ - - - + + + diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index a622b98e05..ec39c495aa 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -1,4 +1,5 @@ -using Bit.Core.Enums; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; using Bit.Core.NotificationCenter.Enums; namespace Bit.Core.Models; @@ -103,3 +104,9 @@ public class LogOutPushNotification public Guid UserId { get; set; } public PushNotificationLogOutReason? Reason { get; set; } } + +public class SyncPolicyPushNotification +{ + public Guid OrganizationId { get; set; } + public required Policy Policy { get; set; } +} diff --git a/src/Core/Platform/Push/PushType.cs b/src/Core/Platform/Push/PushType.cs index 93eca86243..9a601ab0d3 100644 --- a/src/Core/Platform/Push/PushType.cs +++ b/src/Core/Platform/Push/PushType.cs @@ -95,5 +95,8 @@ public enum PushType : byte OrganizationBankAccountVerified = 23, [NotificationInfo("@bitwarden/team-billing-dev", typeof(Models.ProviderBankAccountVerifiedPushNotification))] - ProviderBankAccountVerified = 24 + ProviderBankAccountVerified = 24, + + [NotificationInfo("@bitwarden/team-admin-console-dev", typeof(Models.SyncPolicyPushNotification))] + PolicyChanged = 25, } diff --git a/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs new file mode 100644 index 0000000000..4787125045 --- /dev/null +++ b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs @@ -0,0 +1,30 @@ +namespace Bit.Identity.IdentityServer.RequestValidationConstants; + +public static class CustomResponseConstants +{ + public static class ResponseKeys + { + /// + /// Identifies the error model returned in the custom response when an error occurs. + /// + public static string ErrorModel => "ErrorModel"; + /// + /// This Key is used when a user is in a single organization that requires SSO authentication. The identifier + /// is used by the client to speed the redirection to the correct IdP for the user's organization. + /// + public static string SsoOrganizationIdentifier => "SsoOrganizationIdentifier"; + } +} + +public static class SsoConstants +{ + /// + /// These are messages and errors we return when SSO Validation is unsuccessful + /// + public static class RequestErrors + { + public static string SsoRequired => "sso_required"; + public static string SsoRequiredDescription => "Sso authentication is required."; + public static string SsoTwoFactorRecoveryDescription => "Two-factor recovery has been performed. SSO authentication is required."; + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index 224c7a1866..fdc70b0edf 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -34,6 +34,7 @@ public abstract class BaseRequestValidator where T : class private readonly IEventService _eventService; private readonly IDeviceValidator _deviceValidator; private readonly ITwoFactorAuthenticationValidator _twoFactorAuthenticationValidator; + private readonly ISsoRequestValidator _ssoRequestValidator; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; @@ -43,7 +44,7 @@ public abstract class BaseRequestValidator where T : class protected ICurrentContext CurrentContext { get; } protected IPolicyService PolicyService { get; } - protected IFeatureService FeatureService { get; } + protected IFeatureService _featureService { get; } protected ISsoConfigRepository SsoConfigRepository { get; } protected IUserService _userService { get; } protected IUserDecryptionOptionsBuilder UserDecryptionOptionsBuilder { get; } @@ -56,6 +57,7 @@ public abstract class BaseRequestValidator where T : class IEventService eventService, IDeviceValidator deviceValidator, ITwoFactorAuthenticationValidator twoFactorAuthenticationValidator, + ISsoRequestValidator ssoRequestValidator, IOrganizationUserRepository organizationUserRepository, ILogger logger, ICurrentContext currentContext, @@ -76,13 +78,14 @@ public abstract class BaseRequestValidator where T : class _eventService = eventService; _deviceValidator = deviceValidator; _twoFactorAuthenticationValidator = twoFactorAuthenticationValidator; + _ssoRequestValidator = ssoRequestValidator; _organizationUserRepository = organizationUserRepository; _logger = logger; CurrentContext = currentContext; _globalSettings = globalSettings; PolicyService = policyService; _userRepository = userRepository; - FeatureService = featureService; + _featureService = featureService; SsoConfigRepository = ssoConfigRepository; UserDecryptionOptionsBuilder = userDecryptionOptionsBuilder; PolicyRequirementQuery = policyRequirementQuery; @@ -94,7 +97,7 @@ public abstract class BaseRequestValidator where T : class protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - if (FeatureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) + if (_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) { var validators = DetermineValidationOrder(context, request, validatorContext); var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); @@ -120,15 +123,29 @@ public abstract class BaseRequestValidator where T : class } // 2. Decide if this user belongs to an organization that requires SSO. - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) + // TODO: Clean up Feature Flag: Remove this if block: PM-28281 + if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; + validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); + if (validatorContext.SsoRequired) + { + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return; + } + } + else + { + var ssoValid = await _ssoRequestValidator.ValidateAsync(user, request, validatorContext); + if (!ssoValid) + { + // SSO is required + SetValidationErrorResult(context, validatorContext); + return; + } } // 3. Check if 2FA is required. @@ -355,36 +372,51 @@ public abstract class BaseRequestValidator where T : class private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); - if (!validatorContext.SsoRequired) + // TODO: Clean up Feature Flag: Remove this if block: PM-28281 + if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) { - return true; - } + validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); + if (!validatorContext.SsoRequired) + { + return true; + } - // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are - // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and - // review their new recovery token if desired. - // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. - // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been - // evaluated, and recovery will have been performed if requested. - // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect - // to /login. - if (validatorContext.TwoFactorRequired && - validatorContext.TwoFactorRecoveryRequested) - { - SetSsoResult(context, new Dictionary + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (validatorContext.TwoFactorRequired && + validatorContext.TwoFactorRecoveryRequested) + { + SetSsoResult(context, new Dictionary { { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } }); + return false; + } + + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); return false; } - - SetSsoResult(context, - new Dictionary + else + { + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return false; + return true; + } + + SetValidationErrorResult(context, validatorContext); + return ssoValid; + } } /// @@ -651,6 +683,7 @@ public abstract class BaseRequestValidator where T : class /// user trying to login /// magic string identifying the grant type requested /// true if sso required; false if not required or already in process + [Obsolete("This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] private async Task RequireSsoLoginAsync(User user, string grantType) { if (grantType == "authorization_code" || grantType == "client_credentials") @@ -661,7 +694,7 @@ public abstract class BaseRequestValidator where T : class } // Check if user belongs to any organization with an active SSO policy - var ssoRequired = FeatureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await PolicyRequirementQuery.GetAsync(user.Id)) .SsoRequired : await PolicyService.AnyPoliciesApplicableToUserAsync( @@ -703,7 +736,7 @@ public abstract class BaseRequestValidator where T : class private async Task SendFailedTwoFactorEmail(User user, TwoFactorProviderType failedAttemptType) { - if (FeatureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) + if (_featureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) { await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, CurrentContext.IpAddress); diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 64156ea5f3..4d75da92fe 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -36,6 +36,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -56,6 +57,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public interface ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. Sets the error state in the if SSO is required. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context); +} diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index d69d521ef7..ea2c021f63 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -31,6 +31,7 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -50,6 +51,7 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public class SsoRequestValidator( + IPolicyService _policyService, + IFeatureService _featureService, + IUserSsoOrganizationIdentifierQuery _userSsoOrganizationIdentifierQuery, + IPolicyRequirementQuery _policyRequirementQuery) : ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. + /// Sets context.SsoRequired to indicate whether SSO is required. + /// If SSO is required, sets the validation error result and custom response in the context. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + public async Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context) + { + context.SsoRequired = await RequireSsoAuthenticationAsync(user, request.GrantType); + + if (!context.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + // If the feature flag RecoveryCodeSupportForSsoRequiredUsers is set to false then this code is unreachable since + // Two Factor validation occurs after SSO validation in that scenario. + if (context.TwoFactorRequired && context.TwoFactorRecoveryRequested) + { + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription); + return false; + } + + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoRequiredDescription); + return false; + } + + /// + /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are + /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. + /// If the GrantType is authorization_code or client_credentials we know the user is trying to login + /// using the SSO flow so they are allowed to continue. + /// + /// user trying to login + /// magic string identifying the grant type requested + /// true if sso required; false if not required or already in process + private async Task RequireSsoAuthenticationAsync(User user, string grantType) + { + if (grantType == OidcConstants.GrantTypes.AuthorizationCode || + grantType == OidcConstants.GrantTypes.ClientCredentials) + { + // SSO is not required for users already using SSO to authenticate which uses the authorization_code grant type, + // or logging-in via API key which is the client_credentials grant type. + // Allow user to continue request validation + return false; + } + + // Check if user belongs to any organization with an active SSO policy + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + ? (await _policyRequirementQuery.GetAsync(user.Id)) + .SsoRequired + : await _policyService.AnyPoliciesApplicableToUserAsync( + user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + + if (ssoRequired) + { + return true; + } + + // Default - SSO is not required + return false; + } + + /// + /// Sets the customResponse in the context with the error result for the SSO validation failure. + /// + /// The validator context to update with error details. + /// The error message to return to the client. + private async Task SetContextCustomResponseSsoErrorAsync(CustomValidatorRequestContext context, string errorMessage) + { + var ssoOrganizationIdentifier = await _userSsoOrganizationIdentifierQuery.GetSsoOrganizationIdentifierAsync(context.User.Id); + + context.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = OidcConstants.TokenErrors.InvalidGrant, + ErrorDescription = errorMessage + }; + + context.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(errorMessage) } + }; + + // Include organization identifier in the response if available + if (!string.IsNullOrEmpty(ssoOrganizationIdentifier)) + { + context.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier] = ssoOrganizationIdentifier; + } + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index 294df1c18d..e4cd60827e 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -38,6 +38,7 @@ public class WebAuthnGrantValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -59,6 +60,7 @@ public class WebAuthnGrantValidator : BaseRequestValidator(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient, SendPasswordRequestValidator>(); services.AddTransient, SendEmailOtpRequestValidator>(); diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index b0dec8b415..bc03bb46df 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -231,9 +231,26 @@ public class HubHelpers await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); break; + case PushType.PolicyChanged: + await policyChangedNotificationHandler(notificationJson, cancellationToken); + break; default: _logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); break; } } + + private async Task policyChangedNotificationHandler(string notificationJson, CancellationToken cancellationToken) + { + var policyData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); + if (policyData is null) + { + return; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(policyData.Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, policyData, cancellationToken); + + } } diff --git a/test/Core.IntegrationTest/Core.IntegrationTest.csproj b/test/Core.IntegrationTest/Core.IntegrationTest.csproj index d964452f4c..babe974ffd 100644 --- a/test/Core.IntegrationTest/Core.IntegrationTest.csproj +++ b/test/Core.IntegrationTest/Core.IntegrationTest.csproj @@ -11,11 +11,11 @@ - + - + diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs index b1e3faf257..275466a9bd 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs @@ -6,8 +6,11 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; using Bit.Core.Models.Data.Organizations; +using Bit.Core.Platform.Push; using Bit.Core.Services; using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -95,7 +98,8 @@ public class SavePolicyCommandTests Substitute.For(), [new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()], Substitute.For(), - Substitute.For())); + Substitute.For(), + Substitute.For())); Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message); } @@ -360,6 +364,103 @@ public class SavePolicyCommandTests .ExecuteSideEffectsAsync(default!, default!, default!); } + [Theory, BitAutoData] + public async Task VNextSaveAsync_SendsPushNotification( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + // Arrange + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + var savePolicyModel = new SavePolicyModel(policyUpdate); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + // Act + var result = await sutProvider.Sut.VNextSaveAsync(savePolicyModel); + + // Assert + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_SendsPushNotification([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_SendsPushNotificationWithUpdatedPolicy( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + /// /// Returns a new SutProvider with the PolicyValidators registered in the Sut. /// diff --git a/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs new file mode 100644 index 0000000000..2b448ba79f --- /dev/null +++ b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs @@ -0,0 +1,275 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Auth.UserFeatures.Sso; + +[SutProviderCustomize] +public class UserSsoOrganizationIdentifierQueryTests +{ + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasSingleConfirmedOrganization_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = "test-org-identifier"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("test-org-identifier", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasNoOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId) + { + // Arrange + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(Array.Empty()); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMultipleConfirmedOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser1, + OrganizationUser organizationUser2) + { + // Arrange + organizationUser1.UserId = userId; + organizationUser1.Status = OrganizationUserStatusType.Confirmed; + organizationUser2.UserId = userId; + organizationUser2.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser1, organizationUser2]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserStatusType.Invited)] + [BitAutoData(OrganizationUserStatusType.Accepted)] + [BitAutoData(OrganizationUserStatusType.Revoked)] + public async Task GetSsoOrganizationIdentifierAsync_UserHasOnlyInvitedOrganization_ReturnsNull( + OrganizationUserStatusType status, + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = status; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMixedStatusOrganizations_OnlyOneConfirmed_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser confirmedOrgUser, + OrganizationUser invitedOrgUser, + OrganizationUser revokedOrgUser) + { + // Arrange + confirmedOrgUser.UserId = userId; + confirmedOrgUser.OrganizationId = organization.Id; + confirmedOrgUser.Status = OrganizationUserStatusType.Confirmed; + + invitedOrgUser.UserId = userId; + invitedOrgUser.Status = OrganizationUserStatusType.Invited; + + revokedOrgUser.UserId = userId; + revokedOrgUser.Status = OrganizationUserStatusType.Revoked; + + organization.Identifier = "mixed-status-org"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { confirmedOrgUser, invitedOrgUser, revokedOrgUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("mixed-status-org", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationNotFound_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.OrganizationId) + .Returns((Organization)null); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organizationUser.OrganizationId); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsNull_ReturnsNull( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = null; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsEmpty_ReturnsEmpty( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = string.Empty; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal(string.Empty, result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } +} diff --git a/test/Core.Test/Context/CurrentContextTests.cs b/test/Core.Test/Context/CurrentContextTests.cs index b868d6ceaa..41a54a5b22 100644 --- a/test/Core.Test/Context/CurrentContextTests.cs +++ b/test/Core.Test/Context/CurrentContextTests.cs @@ -107,30 +107,6 @@ public class CurrentContextTests Assert.Equal(deviceType, sutProvider.Sut.DeviceType); } - [Theory, BitAutoData] - public async Task BuildAsync_HttpContext_SetsCloudflareFlags( - SutProvider sutProvider) - { - var httpContext = new DefaultHttpContext(); - var globalSettings = new Core.Settings.GlobalSettings(); - sutProvider.Sut.BotScore = null; - // Arrange - var botScore = 85; - httpContext.Request.Headers["X-Cf-Bot-Score"] = botScore.ToString(); - httpContext.Request.Headers["X-Cf-Worked-Proxied"] = "1"; - httpContext.Request.Headers["X-Cf-Is-Bot"] = "1"; - httpContext.Request.Headers["X-Cf-Maybe-Bot"] = "1"; - - // Act - await sutProvider.Sut.BuildAsync(httpContext, globalSettings); - - // Assert - Assert.True(sutProvider.Sut.CloudflareWorkerProxied); - Assert.True(sutProvider.Sut.IsBot); - Assert.True(sutProvider.Sut.MaybeBot); - Assert.Equal(botScore, sutProvider.Sut.BotScore); - } - [Theory, BitAutoData] public async Task BuildAsync_HttpContext_SetsClientVersion( SutProvider sutProvider) diff --git a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs index d8e944d3b8..a2fc5b19de 100644 --- a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs @@ -74,7 +74,7 @@ public class SendGridMailDeliveryServiceTests : IDisposable Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent); Assert.Equal(mailMessage.TextContent, msg.PlainTextContent); - Assert.Contains("type:Cateogry", msg.Categories); + Assert.Contains("type:Category", msg.Categories); Assert.Contains(msg.Categories, x => x.StartsWith("env:")); Assert.Contains(msg.Categories, x => x.StartsWith("sender:")); diff --git a/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs b/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs index 3063524a57..9dfdf723f3 100644 --- a/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs +++ b/test/Identity.Test/AutoFixture/RequestValidationFixtures.cs @@ -44,14 +44,17 @@ internal class CustomValidatorRequestContextCustomization : ICustomization /// , and /// should initialize false, /// and are made truthy in context upon evaluation of a request. Do not allow AutoFixture to eagerly make these - /// truthy; that is the responsibility of the + /// truthy; that is the responsibility of the . + /// ValidationErrorResult and CustomResponse should also be null initially; they are hydrated during the validation process. /// public void Customize(IFixture fixture) { fixture.Customize(composer => composer .With(o => o.RememberMeRequested, false) .With(o => o.TwoFactorRecoveryRequested, false) - .With(o => o.SsoRequired, false)); + .With(o => o.SsoRequired, false) + .With(o => o.ValidationErrorResult, () => null) + .With(o => o.CustomResponse, () => null)); } } diff --git a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs index e78c7d161c..214fa74ff4 100644 --- a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs @@ -21,6 +21,7 @@ using Bit.Identity.IdentityServer; using Bit.Identity.IdentityServer.RequestValidators; using Bit.Identity.Test.Wrappers; using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; using Duende.IdentityServer.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; @@ -42,6 +43,7 @@ public class BaseRequestValidatorTests private readonly IEventService _eventService; private readonly IDeviceValidator _deviceValidator; private readonly ITwoFactorAuthenticationValidator _twoFactorAuthenticationValidator; + private readonly ISsoRequestValidator _ssoRequestValidator; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly FakeLogger _logger; private readonly ICurrentContext _currentContext; @@ -65,6 +67,7 @@ public class BaseRequestValidatorTests _eventService = Substitute.For(); _deviceValidator = Substitute.For(); _twoFactorAuthenticationValidator = Substitute.For(); + _ssoRequestValidator = Substitute.For(); _organizationUserRepository = Substitute.For(); _logger = new FakeLogger(); _currentContext = Substitute.For(); @@ -85,6 +88,7 @@ public class BaseRequestValidatorTests _eventService, _deviceValidator, _twoFactorAuthenticationValidator, + _ssoRequestValidator, _organizationUserRepository, _logger, _currentContext, @@ -151,6 +155,7 @@ public class BaseRequestValidatorTests // Arrange SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass _sut.isValid = true; @@ -162,9 +167,9 @@ public class BaseRequestValidatorTests // 4 -> set up device validator to fail requestContext.KnownDevice = false; - tokenRequest.GrantType = "password"; + tokenRequest.GrantType = OidcConstants.GrantTypes.Password; _deviceValidator - .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(false)); // 5 -> not legacy user @@ -192,6 +197,7 @@ public class BaseRequestValidatorTests // Arrange SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass _sut.isValid = true; @@ -203,12 +209,13 @@ public class BaseRequestValidatorTests // 4 -> set up device validator to pass _deviceValidator - .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); // 5 -> not legacy user _userService.IsLegacyUser(Arg.Any()) .Returns(false); + _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -236,6 +243,7 @@ public class BaseRequestValidatorTests // Arrange SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass _sut.isValid = true; @@ -262,12 +270,13 @@ public class BaseRequestValidatorTests // 4 -> set up device validator to pass _deviceValidator - .ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); // 5 -> not legacy user _userService.IsLegacyUser(Arg.Any()) .Returns(false); + _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -299,6 +308,7 @@ public class BaseRequestValidatorTests // Arrange SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass _sut.isValid = true; @@ -319,10 +329,19 @@ public class BaseRequestValidatorTests // 2 -> will result to false with no extra configuration // 3 -> set two factor to be required + requestContext.User.TwoFactorProviders = "{\"1\":{\"Enabled\":true,\"MetaData\":{\"Email\":\"user@test.dev\"}}}"; _twoFactorAuthenticationValidator - .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) + .RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(true, null))); + _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(requestContext.User, null) + .Returns(Task.FromResult(new Dictionary + { + { "TwoFactorProviders", new[] { "0", "1" } }, + { "TwoFactorProviders2", new Dictionary{{"Email", null}} } + })); + // Act await _sut.ValidateAsync(context); @@ -330,7 +349,10 @@ public class BaseRequestValidatorTests Assert.True(context.GrantResult.IsError); // Assert that the auth request was NOT consumed - await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + await _authRequestRepository.DidNotReceive().ReplaceAsync(authRequest); + + // Assert that the error is for 2fa + Assert.Equal("Two-factor authentication required.", context.GrantResult.ErrorDescription); } [Theory] @@ -420,6 +442,7 @@ public class BaseRequestValidatorTests { "TwoFactorProviders", new[] { "0", "1" } }, { "TwoFactorProviders2", new Dictionary() } }; + _twoFactorAuthenticationValidator .BuildTwoFactorResultAsync(user, null) .Returns(Task.FromResult(twoFactorResultDict)); @@ -428,6 +451,8 @@ public class BaseRequestValidatorTests await _sut.ValidateAsync(context); // Assert + Assert.Equal("Two-factor authentication required.", context.GrantResult.ErrorDescription); + // Verify that the failed 2FA email was NOT sent for remember token expiration await _mailService.DidNotReceive() .SendFailedTwoFactorAttemptEmailAsync(Arg.Any(), Arg.Any(), @@ -1243,6 +1268,343 @@ public class BaseRequestValidatorTests } } + /// + /// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used. + /// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement + /// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach. + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation( + bool recoveryCodeFeatureEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); + + var context = CreateContext(tokenRequest, requestContext, grantResult); + _sut.isValid = true; + + tokenRequest.GrantType = OidcConstants.GrantTypes.Password; + + // SSO is required via legacy path + _policyService.AnyPoliciesApplicableToUserAsync( + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError); + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + Assert.Equal("SSO authentication is required.", errorResponse.Message); + + // Verify legacy path was used + await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( + requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + + // Verify new SsoRequestValidator was NOT called + await _ssoRequestValidator.DidNotReceive().ValidateAsync( + Arg.Any(), Arg.Any(), Arg.Any()); + } + + /// + /// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used + /// instead of the legacy RequireSsoLoginAsync method. + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator( + bool recoveryCodeFeatureEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); + + var context = CreateContext(tokenRequest, requestContext, grantResult); + _sut.isValid = true; + + tokenRequest.GrantType = OidcConstants.GrantTypes.Password; + + // Configure SsoRequestValidator to indicate SSO is required + _ssoRequestValidator.ValidateAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(false)); // false = SSO required + + // Set up the ValidationErrorResult that SsoRequestValidator would set + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = "sso_required", + ErrorDescription = "SSO authentication is required." + }; + requestContext.CustomResponse = new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }; + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError); + + // Verify new SsoRequestValidator was called + await _ssoRequestValidator.Received(1).ValidateAsync( + requestContext.User, + tokenRequest, + requestContext); + + // Verify legacy path was NOT used + await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( + Arg.Any(), Arg.Any(), Arg.Any()); + } + + /// + /// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required, + /// authentication continues successfully through the new validation path. + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin( + bool recoveryCodeFeatureEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); + + var context = CreateContext(tokenRequest, requestContext, grantResult); + _sut.isValid = true; + + tokenRequest.GrantType = OidcConstants.GrantTypes.Password; + tokenRequest.ClientId = "web"; + + // SsoRequestValidator returns true (SSO not required) + _ssoRequestValidator.ValidateAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(true)); + + // No 2FA required + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) + .Returns(Task.FromResult(new Tuple(false, null))); + + // Device validation passes + _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // User is not legacy + _userService.IsLegacyUser(Arg.Any()).Returns(false); + + _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( + "test-private-key", + "test-public-key" + ) + }); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.False(context.GrantResult.IsError); + await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn); + + // Verify new validator was used + await _ssoRequestValidator.Received(1).ValidateAsync( + requestContext.User, + tokenRequest, + requestContext); + } + + /// + /// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response + /// (e.g., with organization identifier), that custom response is properly propagated to the result. + /// + [Theory] + [BitAutoData(true)] + [BitAutoData(false)] + public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse( + bool recoveryCodeFeatureEnabled, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); + _sut.isValid = true; + + tokenRequest.GrantType = OidcConstants.GrantTypes.Password; + + // SsoRequestValidator sets custom response with organization identifier + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = "sso_required", + ErrorDescription = "SSO authentication is required." + }; + requestContext.CustomResponse = new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }, + { "SsoOrganizationIdentifier", "test-org-identifier" } + }; + + var context = CreateContext(tokenRequest, requestContext, grantResult); + + _ssoRequestValidator.ValidateAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(false)); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError); + Assert.NotNull(context.GrantResult.CustomResponse); + Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse); + Assert.Equal("test-org-identifier", context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]); + } + + /// + /// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery, + /// but SSO is required, the legacy error message is returned (without the recovery-specific message). + /// + [Theory] + [BitAutoData] + public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage( + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); + + var context = CreateContext(tokenRequest, requestContext, grantResult); + _sut.isValid = true; + + // Recovery code scenario + tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); + tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code"; + + // 2FA with recovery + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(requestContext.User, tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code") + .Returns(Task.FromResult(true)); + + // SSO is required (legacy check) + _policyService.AnyPoliciesApplicableToUserAsync( + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError); + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + + // Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired + Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); + + // But legacy validation path was used + await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( + requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + } + + /// + /// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user, + /// the SsoRequestValidator provides the recovery-specific error message. + /// + [Theory] + [BitAutoData] + public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage( + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true); + _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); + + var context = CreateContext(tokenRequest, requestContext, grantResult); + _sut.isValid = true; + + // Recovery code scenario + tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); + tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code"; + + // 2FA with recovery + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(requestContext.User, tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code") + .Returns(Task.FromResult(true)); + + // SsoRequestValidator handles the recovery + SSO scenario + requestContext.TwoFactorRecoveryRequested = true; + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = "sso_required", + ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required." + }; + requestContext.CustomResponse = new Dictionary + { + { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } + }; + + _ssoRequestValidator.ValidateAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(false)); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.True(context.GrantResult.IsError); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"]; + Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); + + // Verify new validator was used + await _ssoRequestValidator.Received(1).ValidateAsync( + requestContext.User, + tokenRequest, + Arg.Is(ctx => ctx.TwoFactorRecoveryRequested)); + + // Verify legacy path was NOT used + await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( + Arg.Any(), Arg.Any(), Arg.Any()); + } + private BaseRequestValidationContextFake CreateContext( ValidatedTokenRequest tokenRequest, CustomValidatorRequestContext requestContext, diff --git a/test/Identity.Test/IdentityServer/SsoRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/SsoRequestValidatorTests.cs new file mode 100644 index 0000000000..2875b5bd37 --- /dev/null +++ b/test/Identity.Test/IdentityServer/SsoRequestValidatorTests.cs @@ -0,0 +1,469 @@ +using Bit.Core; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Services; +using Bit.Identity.IdentityServer; +using Bit.Identity.IdentityServer.Enums; +using Bit.Identity.IdentityServer.RequestValidationConstants; +using Bit.Identity.IdentityServer.RequestValidators; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; +using Duende.IdentityServer.Validation; +using NSubstitute; +using Xunit; +using AuthFixtures = Bit.Identity.Test.AutoFixture; + +namespace Bit.Identity.Test.IdentityServer; + +[SutProviderCustomize] +public class SsoRequestValidatorTests +{ + + [Theory] + [BitAutoData(OidcConstants.GrantTypes.AuthorizationCode)] + [BitAutoData(OidcConstants.GrantTypes.ClientCredentials)] + public async void ValidateAsync_GrantTypeIgnoresSsoRequirement_ReturnsTrue( + string grantType, + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = grantType; + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.True(result); + Assert.False(context.SsoRequired); + Assert.Null(context.ValidationErrorResult); + Assert.Null(context.CustomResponse); + + // Should not check policies since grant type allows bypass + await sutProvider.GetDependency().DidNotReceive() + .AnyPoliciesApplicableToUserAsync(Arg.Any(), Arg.Any(), Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .GetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoNotRequired_RequirementPolicyFeatureFlagEnabled_ReturnsTrue( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = false }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.True(result); + Assert.False(context.SsoRequired); + Assert.Null(context.ValidationErrorResult); + Assert.Null(context.CustomResponse); + + // Should use the new policy requirement query when feature flag is enabled + await sutProvider.GetDependency().Received(1).GetAsync(user.Id); + await sutProvider.GetDependency().DidNotReceive() + .AnyPoliciesApplicableToUserAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoNotRequired_RequirementPolicyFeatureFlagDisabled_ReturnsTrue( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); + + sutProvider.GetDependency().AnyPoliciesApplicableToUserAsync( + user.Id, + PolicyType.RequireSso, + OrganizationUserStatusType.Confirmed) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.True(result); + Assert.False(context.SsoRequired); + Assert.Null(context.ValidationErrorResult); + Assert.Null(context.CustomResponse); + + // Should use the legacy policy service when feature flag is disabled + await sutProvider.GetDependency().Received(1).AnyPoliciesApplicableToUserAsync( + user.Id, + PolicyType.RequireSso, + OrganizationUserStatusType.Confirmed); + await sutProvider.GetDependency().DidNotReceive() + .GetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_RequirementPolicyFeatureFlagEnabled_ReturnsFalse( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.ValidationErrorResult); + Assert.True(context.ValidationErrorResult.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error); + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription); + + Assert.NotNull(context.CustomResponse); + Assert.True(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.ErrorModel)); + Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier)); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_RequirementPolicyFeatureFlagDisabled_ReturnsFalse( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); + + sutProvider.GetDependency().AnyPoliciesApplicableToUserAsync( + user.Id, + PolicyType.RequireSso, + OrganizationUserStatusType.Confirmed) + .Returns(true); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.ValidationErrorResult); + Assert.True(context.ValidationErrorResult.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error); + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription); + + Assert.NotNull(context.CustomResponse); + Assert.True(context.CustomResponse.ContainsKey("ErrorModel")); + Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier")); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_TwoFactorRecoveryRequested_ReturnsFalse_WithSpecialMessage( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + context.TwoFactorRecoveryRequested = true; + context.TwoFactorRequired = true; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.ValidationErrorResult); + Assert.True(context.ValidationErrorResult.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error); + Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", + context.ValidationErrorResult.ErrorDescription); + + Assert.NotNull(context.CustomResponse); + Assert.True(context.CustomResponse.ContainsKey("ErrorModel")); + Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier")); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_TwoFactorRequiredButNotRecovery_ReturnsFalse_WithStandardMessage( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.ValidationErrorResult); + Assert.True(context.ValidationErrorResult.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error); + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription); + + Assert.NotNull(context.CustomResponse); + Assert.True(context.CustomResponse.ContainsKey("ErrorModel")); + Assert.False(context.CustomResponse.ContainsKey("SsoOrganizationIdentifier")); + } + + [Theory] + [BitAutoData(OidcConstants.GrantTypes.Password)] + [BitAutoData(OidcConstants.GrantTypes.RefreshToken)] + [BitAutoData(CustomGrantTypes.WebAuthn)] + public async void ValidateAsync_VariousGrantTypes_SsoRequired_ReturnsFalse( + string grantType, + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = grantType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.ValidationErrorResult); + Assert.True(context.ValidationErrorResult.IsError); + Assert.Equal(OidcConstants.TokenErrors.InvalidGrant, context.ValidationErrorResult.Error); + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, context.ValidationErrorResult.ErrorDescription); + Assert.NotNull(context.CustomResponse); + } + + [Theory, BitAutoData] + public async void ValidateAsync_ContextSsoRequiredUpdated_RegardlessOfInitialValue( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + context.SsoRequired = true; // Start with true to ensure it gets updated + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = false }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.True(result); + Assert.False(context.SsoRequired); // Should be updated to false + Assert.Null(context.ValidationErrorResult); + Assert.Null(context.CustomResponse); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_WithOrganizationIdentifier_IncludesIdentifierInResponse( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + const string orgIdentifier = "test-organization"; + request.GrantType = OidcConstants.GrantTypes.Password; + context.User = user; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns(orgIdentifier); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.CustomResponse); + Assert.True(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier)); + Assert.Equal(orgIdentifier, context.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]); + + await sutProvider.GetDependency() + .Received(1) + .GetSsoOrganizationIdentifierAsync(user.Id); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_NoOrganizationIdentifier_DoesNotIncludeIdentifierInResponse( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + context.User = user; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns((string)null); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.CustomResponse); + Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier)); + + await sutProvider.GetDependency() + .Received(1) + .GetSsoOrganizationIdentifierAsync(user.Id); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoRequired_EmptyOrganizationIdentifier_DoesNotIncludeIdentifierInResponse( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + context.User = user; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + sutProvider.GetDependency() + .GetSsoOrganizationIdentifierAsync(user.Id) + .Returns(string.Empty); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.False(result); + Assert.True(context.SsoRequired); + Assert.NotNull(context.CustomResponse); + Assert.False(context.CustomResponse.ContainsKey(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier)); + + await sutProvider.GetDependency() + .Received(1) + .GetSsoOrganizationIdentifierAsync(user.Id); + } + + [Theory, BitAutoData] + public async void ValidateAsync_SsoNotRequired_DoesNotCallOrganizationIdentifierQuery( + User user, + [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext context, + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest request, + SutProvider sutProvider) + { + // Arrange + request.GrantType = OidcConstants.GrantTypes.Password; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + var requirement = new RequireSsoPolicyRequirement { SsoRequired = false }; + sutProvider.GetDependency().GetAsync(user.Id) + .Returns(requirement); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, request, context); + + // Assert + Assert.True(result); + Assert.False(context.SsoRequired); + + await sutProvider.GetDependency() + .DidNotReceive() + .GetSsoOrganizationIdentifierAsync(Arg.Any()); + } +} diff --git a/test/Identity.Test/IdentityServer/TwoFactorAuthenticationValidatorTests.cs b/test/Identity.Test/IdentityServer/TwoFactorAuthenticationValidatorTests.cs index 53e9a00c9f..c4cbd4b796 100644 --- a/test/Identity.Test/IdentityServer/TwoFactorAuthenticationValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/TwoFactorAuthenticationValidatorTests.cs @@ -32,7 +32,7 @@ public class TwoFactorAuthenticationValidatorTests private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationRepository _organizationRepository; private readonly IDataProtectorTokenFactory _ssoEmail2faSessionTokenable; - private readonly ITwoFactorIsEnabledQuery _twoFactorenabledQuery; + private readonly ITwoFactorIsEnabledQuery _twoFactorEnabledQuery; private readonly ICurrentContext _currentContext; private readonly TwoFactorAuthenticationValidator _sut; @@ -45,7 +45,7 @@ public class TwoFactorAuthenticationValidatorTests _organizationUserRepository = Substitute.For(); _organizationRepository = Substitute.For(); _ssoEmail2faSessionTokenable = Substitute.For>(); - _twoFactorenabledQuery = Substitute.For(); + _twoFactorEnabledQuery = Substitute.For(); _currentContext = Substitute.For(); _sut = new TwoFactorAuthenticationValidator( @@ -56,7 +56,7 @@ public class TwoFactorAuthenticationValidatorTests _organizationUserRepository, _organizationRepository, _ssoEmail2faSessionTokenable, - _twoFactorenabledQuery, + _twoFactorEnabledQuery, _currentContext); } diff --git a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs index ec3e791d5b..b336e4c3c1 100644 --- a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs +++ b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs @@ -54,6 +54,7 @@ IBaseRequestValidatorTestWrapper IEventService eventService, IDeviceValidator deviceValidator, ITwoFactorAuthenticationValidator twoFactorAuthenticationValidator, + ISsoRequestValidator ssoRequestValidator, IOrganizationUserRepository organizationUserRepository, ILogger logger, ICurrentContext currentContext, @@ -73,6 +74,7 @@ IBaseRequestValidatorTestWrapper eventService, deviceValidator, twoFactorAuthenticationValidator, + ssoRequestValidator, organizationUserRepository, logger, currentContext, @@ -132,12 +134,17 @@ IBaseRequestValidatorTestWrapper protected override void SetTwoFactorResult( BaseRequestValidationContextFake context, Dictionary customResponse) - { } + { + context.GrantResult = new GrantValidationResult( + TokenRequestErrors.InvalidGrant, "Two-factor authentication required.", customResponse); + } protected override void SetValidationErrorResult( BaseRequestValidationContextFake context, CustomValidatorRequestContext requestContext) - { } + { + context.GrantResult.IsError = true; + } protected override Task ValidateContextAsync( BaseRequestValidationContextFake context, diff --git a/test/Notifications.Test/HubHelpersTest.cs b/test/Notifications.Test/HubHelpersTest.cs index df4d3c5f85..2cd20858f3 100644 --- a/test/Notifications.Test/HubHelpersTest.cs +++ b/test/Notifications.Test/HubHelpersTest.cs @@ -225,6 +225,30 @@ public class HubHelpersTest .Group(Arg.Any()); } + [Theory] + [BitAutoData] + public async Task SendNotificationToHubAsync_PolicyChanged_SentToOrganizationGroup( + SutProvider sutProvider, + SyncPolicyPushNotification notification, + string contextId, + CancellationToken cancellationToken) + { + var json = ToNotificationJson(notification, PushType.PolicyChanged, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"Organization_{notification.OrganizationId}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && AssertSyncPolicyPushNotification(notification, objects[0], + PushType.PolicyChanged, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + private static string ToNotificationJson(object payload, PushType type, string contextId) { var notification = new PushNotificationData(type, payload, contextId); @@ -247,4 +271,20 @@ public class HubHelpersTest expected.ClientType == pushNotificationData.Payload.ClientType && expected.RevisionDate == pushNotificationData.Payload.RevisionDate; } + + private static bool AssertSyncPolicyPushNotification(SyncPolicyPushNotification expected, object? actual, + PushType type, string contextId) + { + if (actual is not PushNotificationData pushNotificationData) + { + return false; + } + + return pushNotificationData.Type == type && + pushNotificationData.ContextId == contextId && + expected.OrganizationId == pushNotificationData.Payload.OrganizationId && + expected.Policy.Id == pushNotificationData.Payload.Policy.Id && + expected.Policy.Type == pushNotificationData.Payload.Policy.Type && + expected.Policy.Enabled == pushNotificationData.Payload.Policy.Enabled; + } } diff --git a/util/DbSeederUtility/DbSeederUtility.csproj b/util/DbSeederUtility/DbSeederUtility.csproj index 90ac7f22b4..f6195a6763 100644 --- a/util/DbSeederUtility/DbSeederUtility.csproj +++ b/util/DbSeederUtility/DbSeederUtility.csproj @@ -16,7 +16,7 @@ - + diff --git a/util/MsSqlMigratorUtility/MsSqlMigratorUtility.csproj b/util/MsSqlMigratorUtility/MsSqlMigratorUtility.csproj index d316e56161..7e68a91b65 100644 --- a/util/MsSqlMigratorUtility/MsSqlMigratorUtility.csproj +++ b/util/MsSqlMigratorUtility/MsSqlMigratorUtility.csproj @@ -10,7 +10,7 @@ - +