diff --git a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs index a53af41e66..a709a47cb2 100644 --- a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core.Auth.Entities; using Bit.Core.Entities; using Duende.IdentityServer.Validation; @@ -41,4 +42,10 @@ public class CustomValidatorRequestContext /// This will be null if the authentication request is successful. /// public Dictionary CustomResponse { get; set; } + + /// + /// A validated auth request + /// + /// + public AuthRequest ValidatedAuthRequest { get; set; } } diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index 0b33dabb77..3317e18264 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -35,6 +35,7 @@ public abstract class BaseRequestValidator where T : class private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; private readonly IUserRepository _userRepository; + private readonly IAuthRequestRepository _authRequestRepository; protected ICurrentContext CurrentContext { get; } protected IPolicyService PolicyService { get; } @@ -59,7 +60,9 @@ public abstract class BaseRequestValidator where T : class IFeatureService featureService, ISsoConfigRepository ssoConfigRepository, IUserDecryptionOptionsBuilder userDecryptionOptionsBuilder, - IPolicyRequirementQuery policyRequirementQuery) + IPolicyRequirementQuery policyRequirementQuery, + IAuthRequestRepository authRequestRepository + ) { _userManager = userManager; _userService = userService; @@ -76,6 +79,7 @@ public abstract class BaseRequestValidator where T : class SsoConfigRepository = ssoConfigRepository; UserDecryptionOptionsBuilder = userDecryptionOptionsBuilder; PolicyRequirementQuery = policyRequirementQuery; + _authRequestRepository = authRequestRepository; } protected async Task ValidateAsync(T context, ValidatedTokenRequest request, @@ -190,6 +194,14 @@ public abstract class BaseRequestValidator where T : class return; } + // TODO: PM-24324 - This should be its own validator at some point. + // 6. Auth request handling + if (validatorContext.ValidatedAuthRequest != null) + { + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); + } + await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); } @@ -404,8 +416,8 @@ public abstract class BaseRequestValidator where T : class /// /// Builds the custom response that will be sent to the client upon successful authentication, which /// includes the information needed for the client to initialize the user's account in state. - /// - /// The authenticated user. + /// + /// The authenticated user. /// The current request context. /// The device used for authentication. /// Whether to send a 2FA remember token. diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 5042f38b4f..c3d7908dc9 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -45,7 +45,8 @@ public class CustomTokenRequestValidator : BaseRequestValidator(); _userDecryptionOptionsBuilder = Substitute.For(); _policyRequirementQuery = Substitute.For(); + _authRequestRepository = Substitute.For(); _sut = new BaseRequestValidatorTestWrapper( _userManager, @@ -84,7 +87,8 @@ public class BaseRequestValidatorTests _featureService, _ssoConfigRepository, _userDecryptionOptionsBuilder, - _policyRequirementQuery); + _policyRequirementQuery, + _authRequestRepository); } /* Logic path @@ -181,6 +185,99 @@ public class BaseRequestValidatorTests Assert.False(context.GrantResult.IsError); } + [Theory, BitAutoData] + public async Task ValidateAsync_ValidatedAuthRequest_ConsumedOnSuccess( + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass + _sut.isValid = true; + + var authRequest = new AuthRequest + { + Type = AuthRequestType.AuthenticateAndUnlock, + RequestDeviceIdentifier = "", + RequestIpAddress = "1.1.1.1", + AccessCode = "password", + PublicKey = "test_public_key", + CreationDate = DateTime.UtcNow.AddMinutes(-5), + ResponseDate = DateTime.UtcNow.AddMinutes(-2), + Approved = true, + AuthenticationDate = null, // unused + UserId = requestContext.User.Id, + }; + requestContext.ValidatedAuthRequest = authRequest; + + // 2 -> will result to false with no extra configuration + // 3 -> set two factor to be false + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) + .Returns(Task.FromResult(new Tuple(false, null))); + + // 4 -> set up device validator to pass + _deviceValidator.ValidateRequestDeviceAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); + + // 5 -> not legacy user + _userService.IsLegacyUser(Arg.Any()) + .Returns(false); + + // Act + await _sut.ValidateAsync(context); + + // Assert + Assert.False(context.GrantResult.IsError); + + // Check that the auth request was consumed + await _authRequestRepository.Received(1).ReplaceAsync(Arg.Is(ar => + ar.AuthenticationDate.HasValue)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_ValidatedAuthRequest_NotConsumed_When2faRequired( + [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, + CustomValidatorRequestContext requestContext, + GrantValidationResult grantResult) + { + // Arrange + var context = CreateContext(tokenRequest, requestContext, grantResult); + // 1 -> to pass + _sut.isValid = true; + + var authRequest = new AuthRequest + { + Type = AuthRequestType.AuthenticateAndUnlock, + RequestDeviceIdentifier = "", + RequestIpAddress = "1.1.1.1", + AccessCode = "password", + PublicKey = "test_public_key", + CreationDate = DateTime.UtcNow.AddMinutes(-5), + ResponseDate = DateTime.UtcNow.AddMinutes(-2), + Approved = true, + AuthenticationDate = null, // unused + UserId = requestContext.User.Id, + }; + requestContext.ValidatedAuthRequest = authRequest; + + // 2 -> will result to false with no extra configuration + // 3 -> set two factor to be required + _twoFactorAuthenticationValidator + .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) + .Returns(Task.FromResult(new Tuple(true, null))); + + // Act + await _sut.ValidateAsync(context); + + // Assert we errored for 2fa requirement + Assert.True(context.GrantResult.IsError); + + // Assert that the auth request was NOT consumed + await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + // Test grantTypes that require SSO when a user is in an organization that requires it [Theory] [BitAutoData("password")] diff --git a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs index 4c14de2d73..140e171309 100644 --- a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs +++ b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs @@ -62,7 +62,8 @@ IBaseRequestValidatorTestWrapper IFeatureService featureService, ISsoConfigRepository ssoConfigRepository, IUserDecryptionOptionsBuilder userDecryptionOptionsBuilder, - IPolicyRequirementQuery policyRequirementQuery) : + IPolicyRequirementQuery policyRequirementQuery, + IAuthRequestRepository authRequestRepository) : base( userManager, userService, @@ -78,7 +79,8 @@ IBaseRequestValidatorTestWrapper featureService, ssoConfigRepository, userDecryptionOptionsBuilder, - policyRequirementQuery) + policyRequirementQuery, + authRequestRepository) { }