1
0
mirror of https://github.com/bitwarden/server synced 2025-12-26 21:23:39 +00:00

Merge remote-tracking branch 'origin/main' into arch/seeder-api

This commit is contained in:
Matt Gibson
2025-11-04 21:43:51 -08:00
258 changed files with 23154 additions and 2234 deletions

View File

@@ -0,0 +1,197 @@
using System.Net;
using Bit.Api.AdminConsole.Authorization;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.Models.Request.Organizations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Api;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Xunit;
namespace Bit.Api.IntegrationTest.AdminConsole.Controllers;
public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture<ApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly ApiApplicationFactory _factory;
private readonly LoginHelper _loginHelper;
private Organization _organization = null!;
private string _ownerEmail = null!;
public OrganizationUsersControllerPutResetPasswordTests(ApiApplicationFactory apiFactory)
{
_factory = apiFactory;
_factory.SubstituteService<IFeatureService>(featureService =>
{
featureService
.IsEnabled(FeatureFlagKeys.AccountRecoveryCommand)
.Returns(true);
});
_client = _factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}
public async Task InitializeAsync()
{
_ownerEmail = $"reset-password-test-{Guid.NewGuid()}@example.com";
await _factory.LoginWithNewAccount(_ownerEmail);
(_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023,
ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card);
// Enable reset password and policies for the organization
var organizationRepository = _factory.GetService<IOrganizationRepository>();
_organization.UseResetPassword = true;
_organization.UsePolicies = true;
await organizationRepository.ReplaceAsync(_organization);
// Enable the ResetPassword policy
var policyRepository = _factory.GetService<IPolicyRepository>();
await policyRepository.CreateAsync(new Policy
{
OrganizationId = _organization.Id,
Type = PolicyType.ResetPassword,
Enabled = true,
Data = "{}"
});
}
public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}
/// <summary>
/// Helper method to set the ResetPasswordKey on an organization user, which is required for account recovery
/// </summary>
private async Task SetResetPasswordKeyAsync(OrganizationUser orgUser)
{
var organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
orgUser.ResetPasswordKey = "encrypted-reset-password-key";
await organizationUserRepository.ReplaceAsync(orgUser);
}
[Fact]
public async Task PutResetPassword_AsHigherRole_CanRecoverLowerRole()
{
// Arrange
var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Owner);
await _loginHelper.LoginAsync(ownerEmail);
var (_, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.User);
await SetResetPasswordKeyAsync(targetOrgUser);
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
[Fact]
public async Task PutResetPassword_AsLowerRole_CannotRecoverHigherRole()
{
// Arrange
var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Admin);
await _loginHelper.LoginAsync(adminEmail);
var (_, targetOwnerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.Owner);
await SetResetPasswordKeyAsync(targetOwnerOrgUser);
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOwnerOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var model = await response.Content.ReadFromJsonAsync<ErrorResponseModel>();
Assert.Contains(RecoverAccountAuthorizationHandler.FailureReason, model.Message);
}
[Fact]
public async Task PutResetPassword_CannotRecoverProviderAccount()
{
// Arrange - Create owner who will try to recover the provider account
var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Owner);
await _loginHelper.LoginAsync(ownerEmail);
// Create a user who is also a provider user
var (targetUserEmail, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.User);
await SetResetPasswordKeyAsync(targetOrgUser);
// Add the target user as a provider user to a different provider
var providerRepository = _factory.GetService<IProviderRepository>();
var providerUserRepository = _factory.GetService<IProviderUserRepository>();
var userRepository = _factory.GetService<IUserRepository>();
var provider = await providerRepository.CreateAsync(new Provider
{
Name = "Test Provider",
BusinessName = "Test Provider Business",
BillingEmail = "provider@example.com",
Type = ProviderType.Msp,
Status = ProviderStatusType.Created,
Enabled = true
});
var targetUser = await userRepository.GetByEmailAsync(targetUserEmail);
Assert.NotNull(targetUser);
await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = targetUser.Id,
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin
});
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var model = await response.Content.ReadFromJsonAsync<ErrorResponseModel>();
Assert.Equal(RecoverAccountAuthorizationHandler.ProviderFailureReason, model.Message);
}
}

View File

@@ -211,4 +211,200 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
}
}
[Fact]
public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minLength", "not a number" }, // Wrong type - should be int
{ "requireUpper", true }
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Contains("minLength", content); // Verify field name is in error message
}
[Fact]
public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task PutVNext_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minComplexity", "not a number" }, // Wrong type - should be int
{ "minLength", 12 }
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Contains("minComplexity", content); // Verify field name is in error message
}
[Fact]
public async Task PutVNext_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task PutVNext_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.SingleOrg;
var request = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = null
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
[Fact]
public async Task PutVNext_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.TwoFactorAuthentication;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = null
},
Metadata = null
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
}

View File

@@ -64,6 +64,17 @@ public class MembersControllerTests : IClassFixture<ApiApplicationFactory>, IAsy
var (userEmail4, orgUser4) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id,
OrganizationUserType.Admin);
var collection1 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 1", users:
[
new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = false, Manage = true },
new CollectionAccessSelection { Id = orgUser3.Id, ReadOnly = true, HidePasswords = false, Manage = false }
]);
var collection2 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 2", users:
[
new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = true, Manage = false }
]);
var response = await _client.GetAsync($"/public/members");
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var result = await response.Content.ReadFromJsonAsync<ListResponseModel<MemberResponseModel>>();
@@ -71,23 +82,47 @@ public class MembersControllerTests : IClassFixture<ApiApplicationFactory>, IAsy
Assert.Equal(5, result.Data.Count());
// The owner
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner));
var ownerResult = result.Data.SingleOrDefault(m => m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner);
Assert.NotNull(ownerResult);
Assert.Empty(ownerResult.Collections);
// The custom user
// The custom user with collections
var user1Result = result.Data.Single(m => m.Email == userEmail1);
Assert.Equal(OrganizationUserType.Custom, user1Result.Type);
AssertHelper.AssertPropertyEqual(
new PermissionsModel { AccessImportExport = true, ManagePolicies = true, AccessReports = true },
user1Result.Permissions);
// Verify collections
Assert.NotNull(user1Result.Collections);
Assert.Equal(2, user1Result.Collections.Count());
var user1Collection1 = user1Result.Collections.Single(c => c.Id == collection1.Id);
Assert.False(user1Collection1.ReadOnly);
Assert.False(user1Collection1.HidePasswords);
Assert.True(user1Collection1.Manage);
var user1Collection2 = user1Result.Collections.Single(c => c.Id == collection2.Id);
Assert.False(user1Collection2.ReadOnly);
Assert.True(user1Collection2.HidePasswords);
Assert.False(user1Collection2.Manage);
// Everyone else
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail2 && m.Type == OrganizationUserType.Owner));
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail3 && m.Type == OrganizationUserType.User));
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail4 && m.Type == OrganizationUserType.Admin));
// The other owner
var user2Result = result.Data.SingleOrDefault(m => m.Email == userEmail2 && m.Type == OrganizationUserType.Owner);
Assert.NotNull(user2Result);
Assert.Empty(user2Result.Collections);
// The user with one collection
var user3Result = result.Data.SingleOrDefault(m => m.Email == userEmail3 && m.Type == OrganizationUserType.User);
Assert.NotNull(user3Result);
Assert.NotNull(user3Result.Collections);
Assert.Single(user3Result.Collections);
var user3Collection1 = user3Result.Collections.Single(c => c.Id == collection1.Id);
Assert.True(user3Collection1.ReadOnly);
Assert.False(user3Collection1.HidePasswords);
Assert.False(user3Collection1.Manage);
// The admin with no collections
var user4Result = result.Data.SingleOrDefault(m => m.Email == userEmail4 && m.Type == OrganizationUserType.Admin);
Assert.NotNull(user4Result);
Assert.Empty(user4Result.Collections);
}
[Fact]

View File

@@ -160,4 +160,86 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
Assert.Equal(15, data.MinLength);
Assert.Equal(true, data.RequireUpper);
}
[Fact]
public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minLength", "not a number" }, // Wrong type - should be int
{ "requireUpper", true }
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.DisableSend;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = null
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
}

View File

@@ -151,6 +151,28 @@ public static class OrganizationTestHelpers
return group;
}
/// <summary>
/// Creates a collection with optional user and group associations.
/// </summary>
public static async Task<Collection> CreateCollectionAsync(
ApiApplicationFactory factory,
Guid organizationId,
string name,
IEnumerable<CollectionAccessSelection>? users = null,
IEnumerable<CollectionAccessSelection>? groups = null)
{
var collectionRepository = factory.GetService<ICollectionRepository>();
var collection = new Collection
{
OrganizationId = organizationId,
Name = name,
Type = CollectionType.SharedCollection
};
await collectionRepository.CreateAsync(collection, groups, users);
return collection;
}
/// <summary>
/// Enables the Organization Data Ownership policy for the specified organization.
/// </summary>

View File

@@ -0,0 +1,296 @@
using System.Security.Claims;
using Bit.Api.AdminConsole.Authorization;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Test.AutoFixture.OrganizationUserFixtures;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Authorization;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Authorization;
[SutProviderCustomize]
public class RecoverAccountAuthorizationHandlerTests
{
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsProvider_TargetUserNotProvider_Authorized(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null);
MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsNotMemberOrProvider_NotAuthorized(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason);
}
// Pairing of CurrentContextOrganization (current user permissions) and target user role
// Read this as: a ___ can recover the account for a ___
public static IEnumerable<object[]> AuthorizedRoleCombinations => new object[][]
{
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.User],
};
[Theory, BitMemberAutoData(nameof(AuthorizedRoleCombinations))]
public async Task AuthorizeMemberAsync_RecoverEqualOrLesserRoles_TargetUserNotProvider_Authorized(
CurrentContextOrganization currentContextOrganization,
OrganizationUserType targetOrganizationUserType,
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.Type = targetOrganizationUserType;
currentContextOrganization.Id = targetOrganizationUser.OrganizationId;
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
// Pairing of CurrentContextOrganization (current user permissions) and target user role
// Read this as: a ___ cannot recover the account for a ___
public static IEnumerable<object[]> UnauthorizedRoleCombinations => new object[][]
{
// These roles should fail because you cannot recover a greater role
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true} }, OrganizationUserType.Admin],
// These roles are never authorized to recover any account
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.User],
};
[Theory, BitMemberAutoData(nameof(UnauthorizedRoleCombinations))]
public async Task AuthorizeMemberAsync_InvalidRoles_TargetUserNotProvider_Unauthorized(
CurrentContextOrganization currentContextOrganization,
OrganizationUserType targetOrganizationUserType,
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.Type = targetOrganizationUserType;
currentContextOrganization.Id = targetOrganizationUser.OrganizationId;
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_TargetUserIdIsNull_DoesNotBlock(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.UserId = null;
MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser);
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
// This should shortcut the provider escalation check
await sutProvider.GetDependency<IProviderUserRepository>().DidNotReceiveWithAnyArgs()
.GetManyByUserAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsMemberOfAllTargetUserProviders_DoesNotBlock(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal,
Guid providerId1,
Guid providerId2)
{
// Arrange
var targetUserProviders = new List<ProviderUser>
{
new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId },
new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId }
};
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByUserAsync(targetOrganizationUser.UserId!.Value)
.Returns(targetUserProviders);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId1)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId2)
.Returns(true);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserMissingProviderMembership_Blocks(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal,
Guid providerId1,
Guid providerId2)
{
// Arrange
var targetUserProviders = new List<ProviderUser>
{
new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId },
new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId }
};
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByUserAsync(targetOrganizationUser.UserId!.Value)
.Returns(targetUserProviders);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId1)
.Returns(true);
// Not a member of this provider
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId2)
.Returns(false);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.ProviderFailureReason);
}
private static void MockOrganizationClaims(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser,
CurrentContextOrganization? currentContextOrganization)
{
sutProvider.GetDependency<IOrganizationContext>()
.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId)
.Returns(currentContextOrganization);
}
private static void MockCurrentUserIsProvider(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser)
{
sutProvider.GetDependency<IOrganizationContext>()
.IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId)
.Returns(true);
}
private static void MockCurrentUserIsOwner(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser)
{
var currentContextOrganization = new CurrentContextOrganization
{
Id = targetOrganizationUser.OrganizationId,
Type = OrganizationUserType.Owner
};
sutProvider.GetDependency<IOrganizationContext>()
.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId)
.Returns(currentContextOrganization);
}
private static void AssertFailed(AuthorizationHandlerContext context, string expectedMessage)
{
Assert.True(context.HasFailed);
var failureReason = Assert.Single(context.FailureReasons);
Assert.Equal(expectedMessage, failureReason.Message);
}
}

View File

@@ -1,11 +1,14 @@
using System.Security.Claims;
using Bit.Api.AdminConsole.Authorization;
using Bit.Api.AdminConsole.Controllers;
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Models.Request.Organizations;
using Bit.Api.Vault.AuthorizationHandlers.Collections;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
@@ -16,6 +19,7 @@ using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Api;
using Bit.Core.Models.Business;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations;
@@ -30,6 +34,7 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using NSubstitute;
using Xunit;
@@ -440,4 +445,153 @@ public class OrganizationUsersControllerTests
Assert.Equal("Master Password reset is required, but not provided.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_CallsLegacyPath(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(true);
sutProvider.GetDependency<IUserService>().AdminResetPasswordAsync(Arg.Any<OrganizationUserType>(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<IUserService>().Received(1)
.AdminResetPasswordAsync(OrganizationUserType.Owner, orgId, orgUserId, model.NewMasterPasswordHash, model.Key);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_WhenOrgUserTypeIsNull_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(false);
sutProvider.GetDependency<ICurrentContext>().Organizations.Returns(new List<CurrentContextOrganization>());
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_WhenAdminResetPasswordFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(true);
sutProvider.GetDependency<IUserService>().AdminResetPasswordAsync(Arg.Any<OrganizationUserType>(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error 1" }));
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ModelStateDictionary>>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationUserNotFound_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns((OrganizationUser)null);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationIdMismatch_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = Guid.NewGuid();
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenAuthorizationFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Failed());
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ErrorResponseModel>>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountSucceeds_ReturnsOk(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Success());
sutProvider.GetDependency<IAdminRecoverAccountCommand>()
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<IAdminRecoverAccountCommand>().Received(1)
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Success());
sutProvider.GetDependency<IAdminRecoverAccountCommand>()
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error message" }));
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ModelStateDictionary>>(result);
}
}

View File

@@ -54,7 +54,7 @@ public class SavePolicyRequestTests
}
[Theory, BitAutoData]
public async Task ToSavePolicyModelAsync_WithNullData_HandlesCorrectly(
public async Task ToSavePolicyModelAsync_WithEmptyData_HandlesCorrectly(
Guid organizationId,
Guid userId)
{
@@ -68,10 +68,8 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.SingleOrg,
Enabled = false,
Data = null
},
Metadata = null
Enabled = false
}
};
// Act
@@ -100,10 +98,8 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.SingleOrg,
Enabled = false,
Data = null
},
Metadata = null
Enabled = false
}
};
// Act
@@ -133,8 +129,7 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
Enabled = true
},
Metadata = new Dictionary<string, object>
{
@@ -152,7 +147,7 @@ public class SavePolicyRequestTests
}
[Theory, BitAutoData]
public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithNullMetadata_ReturnsEmptyMetadata(
public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithEmptyMetadata_ReturnsEmptyMetadata(
Guid organizationId,
Guid userId)
{
@@ -166,10 +161,8 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
},
Metadata = null
Enabled = true
}
};
// Act
@@ -246,8 +239,7 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.MaximumVaultTimeout,
Enabled = true,
Data = null
Enabled = true
},
Metadata = new Dictionary<string, object>
{
@@ -280,8 +272,7 @@ public class SavePolicyRequestTests
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
Enabled = true
},
Metadata = errorDictionary
};

View File

@@ -0,0 +1,150 @@
using Bit.Api.AdminConsole.Models.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Response;
public class ProfileOrganizationResponseModelTests
{
[Theory, BitAutoData]
public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization)
{
var userId = Guid.NewGuid();
var organizationUserId = Guid.NewGuid();
var providerId = Guid.NewGuid();
var organizationIdsClaimingUser = new[] { organization.Id };
var organizationDetails = new OrganizationUserOrganizationDetails
{
OrganizationId = organization.Id,
UserId = userId,
OrganizationUserId = organizationUserId,
Name = organization.Name,
Enabled = organization.Enabled,
Identifier = organization.Identifier,
PlanType = organization.PlanType,
UsePolicies = organization.UsePolicies,
UseSso = organization.UseSso,
UseKeyConnector = organization.UseKeyConnector,
UseScim = organization.UseScim,
UseGroups = organization.UseGroups,
UseDirectory = organization.UseDirectory,
UseEvents = organization.UseEvents,
UseTotp = organization.UseTotp,
Use2fa = organization.Use2fa,
UseApi = organization.UseApi,
UseResetPassword = organization.UseResetPassword,
UseSecretsManager = organization.UseSecretsManager,
UsePasswordManager = organization.UsePasswordManager,
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,
SelfHost = organization.SelfHost,
Seats = organization.Seats,
MaxCollections = organization.MaxCollections,
MaxStorageGb = organization.MaxStorageGb,
Key = "organization-key",
PublicKey = "public-key",
PrivateKey = "private-key",
LimitCollectionCreation = organization.LimitCollectionCreation,
LimitCollectionDeletion = organization.LimitCollectionDeletion,
LimitItemDeletion = organization.LimitItemDeletion,
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems,
ProviderId = providerId,
ProviderName = "Test Provider",
ProviderType = ProviderType.Msp,
SsoEnabled = true,
SsoConfig = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize(),
SsoExternalId = "external-sso-id",
Permissions = CoreHelpers.ClassToJsonData(new Core.Models.Data.Permissions { ManageUsers = true }),
ResetPasswordKey = "reset-password-key",
FamilySponsorshipFriendlyName = "Family Sponsorship",
FamilySponsorshipLastSyncDate = DateTime.UtcNow.AddDays(-1),
FamilySponsorshipToDelete = false,
FamilySponsorshipValidUntil = DateTime.UtcNow.AddYears(1),
IsAdminInitiated = true,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
AccessSecretsManager = true,
SmSeats = 5,
SmServiceAccounts = 10
};
var result = new ProfileOrganizationResponseModel(organizationDetails, organizationIdsClaimingUser);
Assert.Equal("profileOrganization", result.Object);
Assert.Equal(organization.Id, result.Id);
Assert.Equal(userId, result.UserId);
Assert.Equal(organization.Name, result.Name);
Assert.Equal(organization.Enabled, result.Enabled);
Assert.Equal(organization.Identifier, result.Identifier);
Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType);
Assert.Equal(organization.UsePolicies, result.UsePolicies);
Assert.Equal(organization.UseSso, result.UseSso);
Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector);
Assert.Equal(organization.UseScim, result.UseScim);
Assert.Equal(organization.UseGroups, result.UseGroups);
Assert.Equal(organization.UseDirectory, result.UseDirectory);
Assert.Equal(organization.UseEvents, result.UseEvents);
Assert.Equal(organization.UseTotp, result.UseTotp);
Assert.Equal(organization.Use2fa, result.Use2fa);
Assert.Equal(organization.UseApi, result.UseApi);
Assert.Equal(organization.UseResetPassword, result.UseResetPassword);
Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager);
Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager);
Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium);
Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions);
Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
Assert.Equal(organization.SelfHost, result.SelfHost);
Assert.Equal(organization.Seats, result.Seats);
Assert.Equal(organization.MaxCollections, result.MaxCollections);
Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb);
Assert.Equal(organizationDetails.Key, result.Key);
Assert.True(result.HasPublicAndPrivateKeys);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organizationDetails.ProviderId, result.ProviderId);
Assert.Equal(organizationDetails.ProviderName, result.ProviderName);
Assert.Equal(organizationDetails.ProviderType, result.ProviderType);
Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled);
Assert.True(result.KeyConnectorEnabled);
Assert.Equal("https://keyconnector.example.com", result.KeyConnectorUrl);
Assert.Equal(MemberDecryptionType.KeyConnector, result.SsoMemberDecryptionType);
Assert.True(result.SsoBound);
Assert.Equal(organizationDetails.Status, result.Status);
Assert.Equal(organizationDetails.Type, result.Type);
Assert.Equal(organizationDetails.OrganizationUserId, result.OrganizationUserId);
Assert.True(result.UserIsClaimedByOrganization);
Assert.NotNull(result.Permissions);
Assert.True(result.ResetPasswordEnrolled);
Assert.Equal(organizationDetails.AccessSecretsManager, result.AccessSecretsManager);
Assert.Equal(organizationDetails.FamilySponsorshipFriendlyName, result.FamilySponsorshipFriendlyName);
Assert.Equal(organizationDetails.FamilySponsorshipLastSyncDate, result.FamilySponsorshipLastSyncDate);
Assert.Equal(organizationDetails.FamilySponsorshipToDelete, result.FamilySponsorshipToDelete);
Assert.Equal(organizationDetails.FamilySponsorshipValidUntil, result.FamilySponsorshipValidUntil);
Assert.True(result.IsAdminInitiated);
Assert.False(result.FamilySponsorshipAvailable);
}
}

View File

@@ -0,0 +1,129 @@
using Bit.Api.AdminConsole.Models.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Enums;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Response;
public class ProfileProviderOrganizationResponseModelTests
{
[Theory, BitAutoData]
public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization)
{
var userId = Guid.NewGuid();
var providerId = Guid.NewGuid();
var providerUserId = Guid.NewGuid();
var organizationDetails = new ProviderUserOrganizationDetails
{
OrganizationId = organization.Id,
UserId = userId,
Name = organization.Name,
Enabled = organization.Enabled,
Identifier = organization.Identifier,
PlanType = organization.PlanType,
UsePolicies = organization.UsePolicies,
UseSso = organization.UseSso,
UseKeyConnector = organization.UseKeyConnector,
UseScim = organization.UseScim,
UseGroups = organization.UseGroups,
UseDirectory = organization.UseDirectory,
UseEvents = organization.UseEvents,
UseTotp = organization.UseTotp,
Use2fa = organization.Use2fa,
UseApi = organization.UseApi,
UseResetPassword = organization.UseResetPassword,
UseSecretsManager = organization.UseSecretsManager,
UsePasswordManager = organization.UsePasswordManager,
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,
SelfHost = organization.SelfHost,
Seats = organization.Seats,
MaxCollections = organization.MaxCollections,
MaxStorageGb = organization.MaxStorageGb,
Key = "provider-org-key",
PublicKey = "public-key",
PrivateKey = "private-key",
LimitCollectionCreation = organization.LimitCollectionCreation,
LimitCollectionDeletion = organization.LimitCollectionDeletion,
LimitItemDeletion = organization.LimitItemDeletion,
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems,
ProviderId = providerId,
ProviderName = "Test MSP Provider",
ProviderType = ProviderType.Msp,
SsoEnabled = true,
SsoConfig = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption
}.Serialize(),
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin,
ProviderUserId = providerUserId
};
var result = new ProfileProviderOrganizationResponseModel(organizationDetails);
Assert.Equal("profileProviderOrganization", result.Object);
Assert.Equal(organization.Id, result.Id);
Assert.Equal(userId, result.UserId);
Assert.Equal(organization.Name, result.Name);
Assert.Equal(organization.Enabled, result.Enabled);
Assert.Equal(organization.Identifier, result.Identifier);
Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType);
Assert.Equal(organization.UsePolicies, result.UsePolicies);
Assert.Equal(organization.UseSso, result.UseSso);
Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector);
Assert.Equal(organization.UseScim, result.UseScim);
Assert.Equal(organization.UseGroups, result.UseGroups);
Assert.Equal(organization.UseDirectory, result.UseDirectory);
Assert.Equal(organization.UseEvents, result.UseEvents);
Assert.Equal(organization.UseTotp, result.UseTotp);
Assert.Equal(organization.Use2fa, result.Use2fa);
Assert.Equal(organization.UseApi, result.UseApi);
Assert.Equal(organization.UseResetPassword, result.UseResetPassword);
Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager);
Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager);
Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium);
Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions);
Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
Assert.Equal(organization.SelfHost, result.SelfHost);
Assert.Equal(organization.Seats, result.Seats);
Assert.Equal(organization.MaxCollections, result.MaxCollections);
Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb);
Assert.Equal(organizationDetails.Key, result.Key);
Assert.True(result.HasPublicAndPrivateKeys);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organizationDetails.ProviderId, result.ProviderId);
Assert.Equal(organizationDetails.ProviderName, result.ProviderName);
Assert.Equal(organizationDetails.ProviderType, result.ProviderType);
Assert.Equal(OrganizationUserStatusType.Confirmed, result.Status);
Assert.Equal(OrganizationUserType.Owner, result.Type);
Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled);
Assert.False(result.KeyConnectorEnabled);
Assert.Null(result.KeyConnectorUrl);
Assert.Equal(MemberDecryptionType.TrustedDeviceEncryption, result.SsoMemberDecryptionType);
Assert.False(result.SsoBound);
Assert.NotNull(result.Permissions);
Assert.False(result.Permissions.ManageUsers);
Assert.False(result.ResetPasswordEnrolled);
Assert.False(result.AccessSecretsManager);
}
}

View File

@@ -1,4 +1,5 @@
using Bit.Api.Dirt.Controllers;
using Bit.Api.Dirt.Models.Response;
using Bit.Core.Context;
using Bit.Core.Dirt.Entities;
using Bit.Core.Dirt.Models.Data;
@@ -39,7 +40,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -262,7 +264,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -365,7 +368,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -597,7 +601,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -812,7 +817,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -1050,7 +1056,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]

View File

@@ -75,6 +75,7 @@ public class ImportCiphersControllerTests
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.OrganizationId, Guid.NewGuid().ToString())
.With(c => c.FolderId, Guid.NewGuid().ToString())
.With(c => c.ArchivedDate, (DateTime?)null)
.CreateMany(1).ToArray())
.Create();
@@ -92,6 +93,37 @@ public class ImportCiphersControllerTests
);
}
[Theory, BitAutoData]
public async Task PostImportIndividual_WithArchivedDate_SavesArchivedDate(User user,
IFixture fixture, SutProvider<ImportCiphersController> sutProvider)
{
var archivedDate = DateTime.UtcNow;
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = false;
sutProvider.GetDependency<Core.Services.IUserService>()
.GetProperUserId(Arg.Any<ClaimsPrincipal>())
.Returns(user.Id);
var request = fixture.Build<ImportCiphersRequestModel>()
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.ArchivedDate, archivedDate)
.With(c => c.FolderId, (string)null)
.CreateMany(1).ToArray())
.Create();
await sutProvider.Sut.PostImport(request);
await sutProvider.GetDependency<IImportCiphersCommand>()
.Received()
.ImportIntoIndividualVaultAsync(
Arg.Any<List<Folder>>(),
Arg.Is<List<CipherDetails>>(ciphers => ciphers.First().ArchivedDate == archivedDate),
Arg.Any<IEnumerable<KeyValuePair<int, int>>>(),
user.Id
);
}
/****************************
* PostImport - Organization
****************************/
@@ -156,6 +188,7 @@ public class ImportCiphersControllerTests
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.OrganizationId, Guid.NewGuid().ToString())
.With(c => c.FolderId, Guid.NewGuid().ToString())
.With(c => c.ArchivedDate, (DateTime?)null)
.CreateMany(1).ToArray())
.With(y => y.Collections, fixture.Build<CollectionWithIdRequestModel>()
.With(c => c.Id, orgIdGuid)
@@ -227,6 +260,7 @@ public class ImportCiphersControllerTests
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.OrganizationId, Guid.NewGuid().ToString())
.With(c => c.FolderId, Guid.NewGuid().ToString())
.With(c => c.ArchivedDate, (DateTime?)null)
.CreateMany(1).ToArray())
.With(y => y.Collections, fixture.Build<CollectionWithIdRequestModel>()
.With(c => c.Id, orgIdGuid)
@@ -291,6 +325,7 @@ public class ImportCiphersControllerTests
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.OrganizationId, Guid.NewGuid().ToString())
.With(c => c.FolderId, Guid.NewGuid().ToString())
.With(c => c.ArchivedDate, (DateTime?)null)
.CreateMany(1).ToArray())
.With(y => y.Collections, fixture.Build<CollectionWithIdRequestModel>()
.With(c => c.Id, orgIdGuid)
@@ -354,6 +389,7 @@ public class ImportCiphersControllerTests
.With(x => x.Ciphers, fixture.Build<CipherRequestModel>()
.With(c => c.OrganizationId, Guid.NewGuid().ToString())
.With(c => c.FolderId, Guid.NewGuid().ToString())
.With(c => c.ArchivedDate, (DateTime?)null)
.CreateMany(1).ToArray())
.With(y => y.Collections, fixture.Build<CollectionWithIdRequestModel>()
.With(c => c.Id, orgIdGuid)
@@ -423,6 +459,7 @@ public class ImportCiphersControllerTests
Ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.OrganizationId, orgId.ToString())
.With(_ => _.FolderId, Guid.NewGuid().ToString())
.With(_ => _.ArchivedDate, (DateTime?)null)
.CreateMany(2).ToArray(),
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
@@ -499,6 +536,7 @@ public class ImportCiphersControllerTests
Ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.OrganizationId, orgId.ToString())
.With(_ => _.FolderId, Guid.NewGuid().ToString())
.With(_ => _.ArchivedDate, (DateTime?)null)
.CreateMany(2).ToArray(),
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
@@ -578,6 +616,7 @@ public class ImportCiphersControllerTests
Ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.OrganizationId, orgId.ToString())
.With(_ => _.FolderId, Guid.NewGuid().ToString())
.With(_ => _.ArchivedDate, (DateTime?)null)
.CreateMany(2).ToArray(),
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
@@ -651,6 +690,7 @@ public class ImportCiphersControllerTests
Ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.OrganizationId, orgId.ToString())
.With(_ => _.FolderId, Guid.NewGuid().ToString())
.With(_ => _.ArchivedDate, (DateTime?)null)
.CreateMany(2).ToArray(),
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
@@ -720,6 +760,7 @@ public class ImportCiphersControllerTests
Ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.OrganizationId, orgId.ToString())
.With(_ => _.FolderId, Guid.NewGuid().ToString())
.With(_ => _.ArchivedDate, (DateTime?)null)
.CreateMany(2).ToArray(),
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
@@ -765,6 +806,63 @@ public class ImportCiphersControllerTests
Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task PostImportOrganization_ThrowsException_WhenAnyCipherIsArchived(
SutProvider<ImportCiphersController> sutProvider,
IFixture fixture,
User user
)
{
var orgId = Guid.NewGuid();
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = false;
sutProvider.GetDependency<GlobalSettings>()
.ImportCiphersLimitation = _organizationCiphersLimitations;
SetupUserService(sutProvider, user);
var ciphers = fixture.Build<CipherRequestModel>()
.With(_ => _.ArchivedDate, DateTime.UtcNow)
.CreateMany(2).ToArray();
var request = new ImportOrganizationCiphersRequestModel
{
Collections = new List<CollectionWithIdRequestModel>().ToArray(),
Ciphers = ciphers,
CollectionRelationships = new List<KeyValuePair<int, int>>().ToArray(),
};
sutProvider.GetDependency<ICurrentContext>()
.AccessImportExport(Arg.Any<Guid>())
.Returns(false);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
Arg.Any<IEnumerable<Collection>>(),
Arg.Is<IEnumerable<IAuthorizationRequirement>>(reqs =>
reqs.Contains(BulkCollectionOperations.ImportCiphers)))
.Returns(AuthorizationResult.Failed());
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(Arg.Any<ClaimsPrincipal>(),
Arg.Any<IEnumerable<Collection>>(),
Arg.Is<IEnumerable<IAuthorizationRequirement>>(reqs =>
reqs.Contains(BulkCollectionOperations.Create)))
.Returns(AuthorizationResult.Success());
sutProvider.GetDependency<ICollectionRepository>()
.GetManyByOrganizationIdAsync(orgId)
.Returns(new List<Collection>());
var exception = await Assert.ThrowsAsync<BadRequestException>(async () =>
{
await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request);
});
Assert.Equal("You cannot import archived items into an organization.", exception.Message);
}
private static void SetupUserService(SutProvider<ImportCiphersController> sutProvider, User user)
{
// This is a workaround for the NSubstitute issue with ambiguous arguments

View File

@@ -0,0 +1,221 @@
using Bit.Api.Models.Public.Request;
using Bit.Api.Models.Public.Response;
using Bit.Api.Utilities.DiagnosticTools;
using Bit.Core;
using Bit.Core.Models.Data;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.Utilities.DiagnosticTools;
public class EventDiagnosticLoggerTests
{
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_FeatureFlagEnabled_LogsInformation(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var request = new EventFilterRequestModel()
{
Start = DateTime.UtcNow.AddMinutes(-3),
End = DateTime.UtcNow,
ActingUserId = Guid.NewGuid(),
ItemId = Guid.NewGuid(),
};
var newestEvent = Substitute.For<IEvent>();
newestEvent.Date.Returns(DateTime.UtcNow);
var middleEvent = Substitute.For<IEvent>();
middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1));
var oldestEvent = Substitute.For<IEvent>();
oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-3));
var eventResponses = new List<EventResponseModel>
{
new (newestEvent),
new (middleEvent),
new (oldestEvent)
};
var response = new PagedListResponseModel<EventResponseModel>(eventResponses, "continuation-token");
// Act
logger.LogAggregateData(featureService, organizationId, response, request);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains($"Event count:{eventResponses.Count}") &&
o.ToString().Contains($"newest record:{newestEvent.Date:O}") &&
o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") &&
o.ToString().Contains("HasMore:True") &&
o.ToString().Contains($"Start:{request.Start:o}") &&
o.ToString().Contains($"End:{request.End:o}") &&
o.ToString().Contains($"ActingUserId:{request.ActingUserId}") &&
o.ToString().Contains($"ItemId:{request.ItemId}"))
,
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_FeatureFlagDisabled_DoesNotLog(
Guid organizationId,
EventFilterRequestModel request)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false);
PagedListResponseModel<EventResponseModel> dummy = null;
// Act
logger.LogAggregateData(featureService, organizationId, dummy, request);
// Assert
logger.DidNotReceive().Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_EmptyData_LogsZeroCount(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var request = new EventFilterRequestModel()
{
Start = null,
End = null,
ActingUserId = null,
ItemId = null,
ContinuationToken = null,
};
var response = new PagedListResponseModel<EventResponseModel>(new List<EventResponseModel>(), null);
// Act
logger.LogAggregateData(featureService, organizationId, response, request);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains("Event count:0") &&
o.ToString().Contains("HasMore:False")),
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_FeatureFlagDisabled_DoesNotLog(Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false);
// Act
logger.LogAggregateData(featureService, organizationId, null, null, null, null);
// Assert
logger.DidNotReceive().Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_EmptyData_LogsZeroCount(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
Bit.Api.Models.Response.EventResponseModel[] emptyEvents = [];
// Act
logger.LogAggregateData(featureService, organizationId, emptyEvents, null, null, null);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains("Event count:0") &&
o.ToString().Contains("HasMore:False")),
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_FeatureFlagEnabled_LogsInformation(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var newestEvent = Substitute.For<IEvent>();
newestEvent.Date.Returns(DateTime.UtcNow);
var middleEvent = Substitute.For<IEvent>();
middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1));
var oldestEvent = Substitute.For<IEvent>();
oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-2));
var events = new List<Bit.Api.Models.Response.EventResponseModel>
{
new (newestEvent),
new (middleEvent),
new (oldestEvent)
};
var queryStart = DateTime.UtcNow.AddMinutes(-3);
var queryEnd = DateTime.UtcNow;
const string continuationToken = "continuation-token";
// Act
logger.LogAggregateData(featureService, organizationId, events, continuationToken, queryStart, queryEnd);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains($"Event count:{events.Count}") &&
o.ToString().Contains($"newest record:{newestEvent.Date:O}") &&
o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") &&
o.ToString().Contains("HasMore:True") &&
o.ToString().Contains($"Start:{queryStart:o}") &&
o.ToString().Contains($"End:{queryEnd:o}"))
,
null,
Arg.Any<Func<object, Exception, string>>());
}
}

View File

@@ -285,6 +285,10 @@ public class SyncControllerTests
providerUserRepository
.GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails);
foreach (var p in providerUserOrganizationDetails)
{
p.SsoConfig = null;
}
providerUserRepository
.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed)
.Returns(providerUserOrganizationDetails);

View File

@@ -0,0 +1,391 @@
using Bit.Billing.Controllers;
using Bit.Billing.Models;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Clients;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using BitPayLight.Models.Invoice;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
using Transaction = Bit.Core.Entities.Transaction;
namespace Bit.Billing.Test.Controllers;
using static BitPayConstants;
public class BitPayControllerTests
{
private readonly GlobalSettings _globalSettings = new();
private readonly IBitPayClient _bitPayClient = Substitute.For<IBitPayClient>();
private readonly ITransactionRepository _transactionRepository = Substitute.For<ITransactionRepository>();
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IMailService _mailService = Substitute.For<IMailService>();
private readonly IPaymentService _paymentService = Substitute.For<IPaymentService>();
private readonly IPremiumUserBillingService _premiumUserBillingService =
Substitute.For<IPremiumUserBillingService>();
private const string _validWebhookKey = "valid-webhook-key";
private const string _invalidWebhookKey = "invalid-webhook-key";
public BitPayControllerTests()
{
var bitPaySettings = new GlobalSettings.BitPaySettings { WebhookKey = _validWebhookKey };
_globalSettings.BitPay = bitPaySettings;
}
private BitPayController CreateController() => new(
_globalSettings,
_bitPayClient,
_transactionRepository,
_organizationRepository,
_userRepository,
_providerRepository,
_mailService,
_paymentService,
Substitute.For<ILogger<BitPayController>>(),
_premiumUserBillingService);
[Fact]
public async Task PostIpn_InvalidKey_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var result = await controller.PostIpn(eventModel, _invalidWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid key", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NullKey_ThrowsException()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
await Assert.ThrowsAsync<ArgumentNullException>(() => controller.PostIpn(eventModel, null!));
}
[Fact]
public async Task PostIpn_EmptyKey_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var result = await controller.PostIpn(eventModel, string.Empty);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid key", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NonUsdCurrency_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(currency: "EUR");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Cannot process non-USD payments", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NullPosData_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: null!);
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_EmptyPosData_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: "");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_PosDataWithoutAccountCredit_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: "organizationId:550e8400-e29b-41d4-a716-446655440000");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_PosDataWithoutValidId_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: PosDataKeys.AccountCredit);
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_IncompleteInvoice_Ok()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(status: "paid");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal("Waiting for invoice to be completed", okResult.Value);
}
[Fact]
public async Task PostIpn_ExistingTransaction_Ok()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice();
var existingTransaction = new Transaction { GatewayId = invoice.Id };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns(existingTransaction);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal("Invoice already processed", okResult.Value);
}
[Fact]
public async Task PostIpn_ValidOrganizationTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var organizationId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}");
var organization = new Organization { Id = organizationId, BillingEmail = "billing@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
_paymentService.CreditAccountAsync(organization, Arg.Any<decimal>()).Returns(true);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.OrganizationId == organizationId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _organizationRepository.Received(1).ReplaceAsync(organization);
await _mailService.Received(1).SendAddedCreditAsync("billing@example.com", 100.00m);
}
[Fact]
public async Task PostIpn_ValidUserTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var userId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}");
var user = new User { Id = userId, Email = "user@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_userRepository.GetByIdAsync(userId).Returns(user);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.UserId == userId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _premiumUserBillingService.Received(1).Credit(user, 100.00m);
await _mailService.Received(1).SendAddedCreditAsync("user@example.com", 100.00m);
}
[Fact]
public async Task PostIpn_ValidProviderTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var providerId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}");
var provider = new Provider { Id = providerId, BillingEmail = "provider@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_providerRepository.GetByIdAsync(providerId).Returns(Task.FromResult(provider));
_paymentService.CreditAccountAsync(provider, Arg.Any<decimal>()).Returns(true);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.ProviderId == providerId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _providerRepository.Received(1).ReplaceAsync(provider);
await _mailService.Received(1).SendAddedCreditAsync("provider@example.com", 100.00m);
}
[Fact]
public void GetIdsFromPosData_ValidOrganizationId_ReturnsCorrectId()
{
var controller = CreateController();
var organizationId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Equal(organizationId, result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_ValidUserId_ReturnsCorrectId()
{
var controller = CreateController();
var userId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Equal(userId, result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_ValidProviderId_ReturnsCorrectId()
{
var controller = CreateController();
var providerId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Equal(providerId, result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_InvalidGuid_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: "organizationId:invalid-guid,{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_NullPosData_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: null!);
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_EmptyPosData_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: "");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
private static BitPayEventModel CreateValidEventModel(string invoiceId = "test-invoice-id")
{
return new BitPayEventModel
{
Event = new BitPayEventModel.EventModel { Code = 1005, Name = "invoice_confirmed" },
Data = new BitPayEventModel.InvoiceDataModel { Id = invoiceId }
};
}
private static Invoice CreateValidInvoice(string invoiceId = "test-invoice-id", string status = "complete",
string currency = "USD", decimal price = 100.00m,
string posData = "organizationId:550e8400-e29b-41d4-a716-446655440000,accountCredit:1")
{
return new Invoice
{
Id = invoiceId,
Status = status,
Currency = currency,
Price = (double)price,
PosData = posData,
CurrentTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
Transactions =
[
new InvoiceTransaction
{
Type = null,
Confirmations = "1",
ReceivedTime = DateTime.UtcNow.ToString("O")
}
]
};
}
}

View File

@@ -0,0 +1,234 @@
using Bit.Billing.Jobs;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Quartz;
using Xunit;
namespace Bit.Billing.Test.Jobs;
public class ProviderOrganizationDisableJobTests
{
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly ILogger<ProviderOrganizationDisableJob> _logger;
private readonly ProviderOrganizationDisableJob _sut;
public ProviderOrganizationDisableJobTests()
{
_providerOrganizationRepository = Substitute.For<IProviderOrganizationRepository>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_logger = Substitute.For<ILogger<ProviderOrganizationDisableJob>>();
_sut = new ProviderOrganizationDisableJob(
_providerOrganizationRepository,
_organizationDisableCommand,
_logger);
}
[Fact]
public async Task Execute_NoOrganizations_LogsAndReturns()
{
// Arrange
var providerId = Guid.NewGuid();
var context = CreateJobExecutionContext(providerId, DateTime.UtcNow);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns((ICollection<ProviderOrganizationOrganizationDetails>)null);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
}
[Fact]
public async Task Execute_WithOrganizations_DisablesAllOrganizations()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var org1Id = Guid.NewGuid();
var org2Id = Guid.NewGuid();
var org3Id = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = org1Id },
new() { OrganizationId = org2Id },
new() { OrganizationId = org3Id }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_WithExpirationDate_PassesDateToDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = new DateTime(2025, 12, 31, 23, 59, 59);
var orgId = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = orgId }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(orgId, expirationDate);
}
[Fact]
public async Task Execute_WithNullExpirationDate_PassesNullToDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var orgId = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = orgId }
};
var context = CreateJobExecutionContext(providerId, null);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(orgId, null);
}
[Fact]
public async Task Execute_OneOrganizationFails_ContinuesProcessingOthers()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var org1Id = Guid.NewGuid();
var org2Id = Guid.NewGuid();
var org3Id = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = org1Id },
new() { OrganizationId = org2Id },
new() { OrganizationId = org3Id }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Make org2 fail
_organizationDisableCommand.DisableAsync(org2Id, Arg.Any<DateTime?>())
.Throws(new Exception("Database error"));
// Act
await _sut.Execute(context);
// Assert - all three should be attempted
await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_ManyOrganizations_ProcessesWithLimitedConcurrency()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
// Create 20 organizations
var organizations = Enumerable.Range(1, 20)
.Select(_ => new ProviderOrganizationOrganizationDetails { OrganizationId = Guid.NewGuid() })
.ToList();
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
var concurrentCalls = 0;
var maxConcurrentCalls = 0;
var lockObj = new object();
_organizationDisableCommand.DisableAsync(Arg.Any<Guid>(), Arg.Any<DateTime?>())
.Returns(callInfo =>
{
lock (lockObj)
{
concurrentCalls++;
if (concurrentCalls > maxConcurrentCalls)
{
maxConcurrentCalls = concurrentCalls;
}
}
return Task.Delay(50).ContinueWith(_ =>
{
lock (lockObj)
{
concurrentCalls--;
}
});
});
// Act
await _sut.Execute(context);
// Assert
Assert.True(maxConcurrentCalls <= 5, $"Expected max concurrency of 5, but got {maxConcurrentCalls}");
await _organizationDisableCommand.Received(20).DisableAsync(Arg.Any<Guid>(), Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_EmptyOrganizationsList_DoesNotCallDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var context = CreateJobExecutionContext(providerId, DateTime.UtcNow);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(new List<ProviderOrganizationOrganizationDetails>());
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
}
private static IJobExecutionContext CreateJobExecutionContext(Guid providerId, DateTime? expirationDate)
{
var context = Substitute.For<IJobExecutionContext>();
var jobDataMap = new JobDataMap
{
{ "providerId", providerId.ToString() },
{ "expirationDate", expirationDate?.ToString("O") }
};
context.MergedJobDataMap.Returns(jobDataMap);
return context;
}
}

View File

@@ -1,10 +1,15 @@
using Bit.Billing.Constants;
using Bit.Billing.Jobs;
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Extensions;
using Bit.Core.Services;
using NSubstitute;
using Quartz;
using Stripe;
using Xunit;
@@ -16,6 +21,10 @@ public class SubscriptionDeletedHandlerTests
private readonly IUserService _userService;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IScheduler _scheduler;
private readonly SubscriptionDeletedHandler _sut;
public SubscriptionDeletedHandlerTests()
@@ -24,11 +33,19 @@ public class SubscriptionDeletedHandlerTests
_userService = Substitute.For<IUserService>();
_stripeEventUtilityService = Substitute.For<IStripeEventUtilityService>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_providerRepository = Substitute.For<IProviderRepository>();
_providerService = Substitute.For<IProviderService>();
_schedulerFactory = Substitute.For<ISchedulerFactory>();
_scheduler = Substitute.For<IScheduler>();
_schedulerFactory.GetScheduler().Returns(_scheduler);
_sut = new SubscriptionDeletedHandler(
_stripeEventService,
_userService,
_stripeEventUtilityService,
_organizationDisableCommand);
_organizationDisableCommand,
_providerRepository,
_providerService,
_schedulerFactory);
}
[Fact]
@@ -59,6 +76,7 @@ public class SubscriptionDeletedHandlerTests
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
await _userService.DidNotReceiveWithAnyArgs().DisablePremiumAsync(default, default);
await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default);
}
[Fact]
@@ -192,4 +210,120 @@ public class SubscriptionDeletedHandlerTests
await _organizationDisableCommand.DidNotReceiveWithAnyArgs()
.DisableAsync(default, default);
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_DisablesProviderAndQueuesJob()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var provider = new Provider
{
Id = providerId,
Enabled = true
};
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await _sut.HandleAsync(stripeEvent);
// Assert
Assert.False(provider.Enabled);
await _providerService.Received(1).UpdateAsync(provider);
await _scheduler.Received(1).ScheduleJob(
Arg.Is<IJobDetail>(j => j.JobType == typeof(ProviderOrganizationDisableJob)),
Arg.Any<ITrigger>());
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_ProviderNotFound_DoesNotThrow()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns((Provider)null);
// Act & Assert - Should not throw
await _sut.HandleAsync(stripeEvent);
// Assert
await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default);
await _scheduler.DidNotReceiveWithAnyArgs().ScheduleJob(default, default);
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_QueuesJobWithCorrectParameters()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var provider = new Provider
{
Id = providerId,
Enabled = true
};
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = expirationDate }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await _sut.HandleAsync(stripeEvent);
// Assert
Assert.False(provider.Enabled);
await _providerService.Received(1).UpdateAsync(provider);
await _scheduler.Received(1).ScheduleJob(
Arg.Is<IJobDetail>(j =>
j.JobType == typeof(ProviderOrganizationDisableJob) &&
j.JobDataMap.GetString("providerId") == providerId.ToString() &&
j.JobDataMap.GetString("expirationDate") == expirationDate.ToString("O")),
Arg.Is<ITrigger>(t => t.Key.Name == $"disable-trigger-{providerId}"));
}
}

View File

@@ -1,7 +1,7 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using MailKit.Security;
using Microsoft.Extensions.Logging;

View File

@@ -1,4 +1,6 @@
using AutoFixture;
using System.Reflection;
using AutoFixture;
using AutoFixture.Xunit2;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
@@ -23,6 +25,7 @@ public class CurrentContextOrganizationCustomization : ICustomization
}
}
[AttributeUsage(AttributeTargets.Method)]
public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribute
{
public Guid Id { get; set; }
@@ -38,3 +41,19 @@ public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribut
AccessSecretsManager = AccessSecretsManager
};
}
public class CurrentContextOrganizationAttribute : CustomizeAttribute
{
public Guid Id { get; set; }
public OrganizationUserType Type { get; set; } = OrganizationUserType.User;
public Permissions Permissions { get; set; } = new();
public bool AccessSecretsManager { get; set; } = false;
public override ICustomization GetCustomization(ParameterInfo _) => new CurrentContextOrganizationCustomization
{
Id = Id,
Type = Type,
Permissions = Permissions,
AccessSecretsManager = AccessSecretsManager
};
}

View File

@@ -8,6 +8,7 @@ namespace Bit.Core.Test.Models.Data.EventIntegrations;
public class IntegrationMessageTests
{
private const string _messageId = "TestMessageId";
private const string _organizationId = "TestOrganizationId";
[Fact]
public void ApplyRetry_IncrementsRetryCountAndSetsDelayUntilDate()
@@ -16,6 +17,7 @@ public class IntegrationMessageTests
{
Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"),
MessageId = _messageId,
OrganizationId = _organizationId,
RetryCount = 2,
RenderedTemplate = string.Empty,
DelayUntilDate = null
@@ -36,6 +38,7 @@ public class IntegrationMessageTests
{
Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"),
MessageId = _messageId,
OrganizationId = _organizationId,
RenderedTemplate = "This is the message",
IntegrationType = IntegrationType.Webhook,
RetryCount = 2,
@@ -48,6 +51,7 @@ public class IntegrationMessageTests
Assert.NotNull(result);
Assert.Equal(message.Configuration, result.Configuration);
Assert.Equal(message.MessageId, result.MessageId);
Assert.Equal(message.OrganizationId, result.OrganizationId);
Assert.Equal(message.RenderedTemplate, result.RenderedTemplate);
Assert.Equal(message.IntegrationType, result.IntegrationType);
Assert.Equal(message.RetryCount, result.RetryCount);
@@ -67,6 +71,7 @@ public class IntegrationMessageTests
var message = new IntegrationMessage
{
MessageId = _messageId,
OrganizationId = _organizationId,
RenderedTemplate = "This is the message",
IntegrationType = IntegrationType.Webhook,
RetryCount = 2,
@@ -77,6 +82,7 @@ public class IntegrationMessageTests
var result = JsonSerializer.Deserialize<IntegrationMessage>(json);
Assert.Equal(message.MessageId, result.MessageId);
Assert.Equal(message.OrganizationId, result.OrganizationId);
Assert.Equal(message.RenderedTemplate, result.RenderedTemplate);
Assert.Equal(message.IntegrationType, result.IntegrationType);
Assert.Equal(message.RetryCount, result.RetryCount);

View File

@@ -0,0 +1,296 @@
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AutoFixture.OrganizationUserFixtures;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.AccountRecovery;
[SutProviderCustomize]
public class AdminRecoverAccountCommandTests
{
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_Success(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
User user,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
SetupValidUser(sutProvider, user, organizationUser);
SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword);
// Act
var result = await sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key);
// Assert
Assert.True(result.Succeeded);
await AssertSuccessAsync(sutProvider, user, key, organization, organizationUser);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_OrganizationDoesNotExist_ThrowsBadRequest(
[OrganizationUser] OrganizationUser organizationUser,
string newMasterPassword,
string key,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
var orgId = Guid.NewGuid();
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgId)
.Returns((Organization)null);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(orgId, organizationUser, newMasterPassword, key));
Assert.Equal("Organization does not allow password reset.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_OrganizationDoesNotAllowResetPassword_ThrowsBadRequest(
string newMasterPassword,
string key,
Organization organization,
[OrganizationUser] OrganizationUser organizationUser,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
organization.UseResetPassword = false;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Organization does not allow password reset.", exception.Message);
}
public static IEnumerable<object[]> InvalidPolicies => new object[][]
{
[new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null]
};
[Theory]
[BitMemberAutoData(nameof(InvalidPolicies))]
public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest(
Policy resetPasswordPolicy,
string newMasterPassword,
string key,
Organization organization,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
.Returns(resetPasswordPolicy);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, new OrganizationUser { Id = Guid.NewGuid() },
newMasterPassword, key));
Assert.Equal("Organization does not have the password reset policy enabled.", exception.Message);
}
public static IEnumerable<object[]> InvalidOrganizationUsers()
{
// Make an organization so we can use its Id
var organization = new Fixture().Create<Organization>();
var nonConfirmed = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = organization.Id,
Status = OrganizationUserStatusType.Invited
};
yield return [nonConfirmed, organization];
var wrongOrganization = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = Guid.NewGuid(), // Different org
ResetPasswordKey = "test-key",
UserId = Guid.NewGuid(),
};
yield return [wrongOrganization, organization];
var nullResetPasswordKey = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = null,
UserId = Guid.NewGuid(),
};
yield return [nullResetPasswordKey, organization];
var emptyResetPasswordKey = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = "",
UserId = Guid.NewGuid(),
};
yield return [emptyResetPasswordKey, organization];
var nullUserId = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = "test-key",
UserId = null,
};
yield return [nullUserId, organization];
}
[Theory]
[BitMemberAutoData(nameof(InvalidOrganizationUsers))]
public async Task RecoverAccountAsync_OrganizationUserIsInvalid_ThrowsBadRequest(
OrganizationUser organizationUser,
Organization organization,
string newMasterPassword,
string key,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Organization User not valid", exception.Message);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_UserDoesNotExist_ThrowsNotFoundException(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(organizationUser.UserId!.Value)
.Returns((User)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_UserUsesKeyConnector_ThrowsBadRequest(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
User user,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
user.UsesKeyConnector = true;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(organizationUser.UserId!.Value)
.Returns(user);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Cannot reset password of a user with Key Connector.", exception.Message);
}
private static void SetupValidOrganization(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization)
{
organization.UseResetPassword = true;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
}
private static void SetupValidPolicy(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization)
{
var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true };
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
.Returns(policy);
}
private static void SetupValidOrganizationUser(OrganizationUser organizationUser, Guid orgId)
{
organizationUser.Status = OrganizationUserStatusType.Confirmed;
organizationUser.OrganizationId = orgId;
organizationUser.ResetPasswordKey = "test-key";
organizationUser.Type = OrganizationUserType.User;
}
private static void SetupValidUser(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, OrganizationUser organizationUser)
{
user.Id = organizationUser.UserId!.Value;
user.UsesKeyConnector = false;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(user.Id)
.Returns(user);
}
private static void SetupSuccessfulPasswordUpdate(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, string newMasterPassword)
{
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(user, newMasterPassword)
.Returns(IdentityResult.Success);
}
private static async Task AssertSuccessAsync(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, string key,
Organization organization, OrganizationUser organizationUser)
{
await sutProvider.GetDependency<IUserRepository>().Received(1).ReplaceAsync(
Arg.Is<User>(u =>
u.Id == user.Id &&
u.Key == key &&
u.ForcePasswordReset == true &&
u.RevisionDate == u.AccountRevisionDate &&
u.LastPasswordChangeDate == u.RevisionDate));
await sutProvider.GetDependency<IMailService>().Received(1).SendAdminResetPasswordEmailAsync(
Arg.Is(user.Email),
Arg.Is(user.Name),
Arg.Is(organization.DisplayName()));
await sutProvider.GetDependency<IEventService>().Received(1).LogOrganizationUserEventAsync(
Arg.Is(organizationUser),
Arg.Is(EventType.OrganizationUser_AdminResetPassword));
await sutProvider.GetDependency<IPushNotificationService>().Received(1).PushLogOutAsync(
Arg.Is(user.Id));
}
}

View File

@@ -14,10 +14,12 @@ public class PolicyRequirementQueryTests
[Theory, BitAutoData]
public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId)
{
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.RequireSso };
var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId };
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserId(userId).Returns([otherPolicy, thisPolicy]);
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([otherPolicy, thisPolicy]);
var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
@@ -33,9 +35,11 @@ public class PolicyRequirementQueryTests
{
// Arrange policies
var policyRepository = Substitute.For<IPolicyRepository>();
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
policyRepository.GetPolicyDetailsByUserId(userId).Returns([thisPolicy, otherPolicy]);
var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([thisPolicy, otherPolicy]);
// Arrange a substitute Enforce function so that we can inspect the received calls
var callback = Substitute.For<Func<PolicyDetails, bool>>();
@@ -70,7 +74,9 @@ public class PolicyRequirementQueryTests
public async Task GetAsync_HandlesNoPolicies(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserId(userId).Returns([]);
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([]);
var factory = new TestPolicyRequirementFactory(x => x.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);

View File

@@ -0,0 +1,28 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class UriMatchDefaultPolicyValidatorTests
{
private readonly UriMatchDefaultPolicyValidator _validator = new();
[Fact]
// Test that the Type property returns the correct PolicyType for this validator
public void Type_ReturnsUriMatchDefaults()
{
Assert.Equal(PolicyType.UriMatchDefaults, _validator.Type);
}
[Fact]
// Test that the RequiredPolicies property returns exactly one policy (SingleOrg) as a prerequisite
// for enabling the UriMatchDefaults policy, ensuring proper policy dependency enforcement
public void RequiredPolicies_ReturnsSingleOrgPolicy()
{
var requiredPolicies = _validator.RequiredPolicies.ToList();
Assert.Single(requiredPolicies);
Assert.Contains(PolicyType.SingleOrg, requiredPolicies);
}
}

View File

@@ -22,18 +22,20 @@ public class EventIntegrationEventWriteServiceTests
[Theory, BitAutoData]
public async Task CreateAsync_EventPublishedToEventQueue(EventMessage eventMessage)
{
var expected = JsonSerializer.Serialize(eventMessage);
await Subject.CreateAsync(eventMessage);
await _eventIntegrationPublisher.Received(1).PublishEventAsync(
Arg.Is<string>(body => AssertJsonStringsMatch(eventMessage, body)));
body: Arg.Is<string>(body => AssertJsonStringsMatch(eventMessage, body)),
organizationId: Arg.Is<string>(orgId => eventMessage.OrganizationId.ToString().Equals(orgId)));
}
[Theory, BitAutoData]
public async Task CreateManyAsync_EventsPublishedToEventQueue(IEnumerable<EventMessage> eventMessages)
{
var eventMessage = eventMessages.First();
await Subject.CreateManyAsync(eventMessages);
await _eventIntegrationPublisher.Received(1).PublishEventAsync(
Arg.Is<string>(body => AssertJsonStringsMatch(eventMessages, body)));
body: Arg.Is<string>(body => AssertJsonStringsMatch(eventMessages, body)),
organizationId: Arg.Is<string>(orgId => eventMessage.OrganizationId.ToString().Equals(orgId)));
}
private static bool AssertJsonStringsMatch(EventMessage expected, string body)

View File

@@ -23,6 +23,7 @@ public class EventIntegrationHandlerTests
private const string _templateWithOrganization = "Org: #OrganizationName#";
private const string _templateWithUser = "#UserName#, #UserEmail#";
private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#";
private static readonly Guid _organizationId = Guid.NewGuid();
private static readonly Uri _uri = new Uri("https://localhost");
private static readonly Uri _uri2 = new Uri("https://example.com");
private readonly IEventIntegrationPublisher _eventIntegrationPublisher = Substitute.For<IEventIntegrationPublisher>();
@@ -50,6 +51,7 @@ public class EventIntegrationHandlerTests
{
IntegrationType = IntegrationType.Webhook,
MessageId = "TestMessageId",
OrganizationId = _organizationId.ToString(),
Configuration = new WebhookIntegrationConfigurationDetails(_uri),
RenderedTemplate = template,
RetryCount = 0,
@@ -122,6 +124,7 @@ public class EventIntegrationHandlerTests
public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage)
{
var sutProvider = GetSutProvider(OneConfiguration(_templateBase));
eventMessage.OrganizationId = _organizationId;
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -140,6 +143,7 @@ public class EventIntegrationHandlerTests
public async Task HandleEventAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages(EventMessage eventMessage)
{
var sutProvider = GetSutProvider(TwoConfigurations(_templateBase));
eventMessage.OrganizationId = _organizationId;
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -164,6 +168,7 @@ public class EventIntegrationHandlerTests
var user = Substitute.For<User>();
user.Email = "test@example.com";
user.Name = "Test";
eventMessage.OrganizationId = _organizationId;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(Arg.Any<Guid>()).Returns(user);
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -183,6 +188,7 @@ public class EventIntegrationHandlerTests
var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization));
var organization = Substitute.For<Organization>();
organization.Name = "Test";
eventMessage.OrganizationId = _organizationId;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(Arg.Any<Guid>()).Returns(organization);
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -205,6 +211,7 @@ public class EventIntegrationHandlerTests
var user = Substitute.For<User>();
user.Email = "test@example.com";
user.Name = "Test";
eventMessage.OrganizationId = _organizationId;
sutProvider.GetDependency<IUserRepository>().GetByIdAsync(Arg.Any<Guid>()).Returns(user);
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -235,6 +242,7 @@ public class EventIntegrationHandlerTests
var sutProvider = GetSutProvider(ValidFilterConfiguration());
sutProvider.GetDependency<IIntegrationFilterService>().EvaluateFilterGroup(
Arg.Any<IntegrationFilterGroup>(), Arg.Any<EventMessage>()).Returns(true);
eventMessage.OrganizationId = _organizationId;
await sutProvider.Sut.HandleEventAsync(eventMessage);
@@ -284,7 +292,7 @@ public class EventIntegrationHandlerTests
$"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}"
);
await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(
AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" })));
AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId", "OrganizationId" })));
}
}
@@ -301,12 +309,12 @@ public class EventIntegrationHandlerTests
var expectedMessage = EventIntegrationHandlerTests.expectedMessage(
$"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}"
);
await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(
AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" })));
await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual(
expectedMessage, new[] { "MessageId", "OrganizationId" })));
expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2);
await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(
AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" })));
await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual(
expectedMessage, new[] { "MessageId", "OrganizationId" })));
}
}
}

View File

@@ -16,6 +16,7 @@ public class IntegrationHandlerTests
{
Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"),
MessageId = "TestMessageId",
OrganizationId = "TestOrganizationId",
IntegrationType = IntegrationType.Webhook,
RenderedTemplate = "Template",
DelayUntilDate = null,
@@ -25,6 +26,8 @@ public class IntegrationHandlerTests
var result = await sut.HandleAsync(expected.ToJson());
var typedResult = Assert.IsType<IntegrationMessage<WebhookIntegrationConfigurationDetails>>(result.Message);
Assert.Equal(expected.MessageId, typedResult.MessageId);
Assert.Equal(expected.OrganizationId, typedResult.OrganizationId);
Assert.Equal(expected.Configuration, typedResult.Configuration);
Assert.Equal(expected.RenderedTemplate, typedResult.RenderedTemplate);
Assert.Equal(expected.IntegrationType, typedResult.IntegrationType);

View File

@@ -0,0 +1,59 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.Utilities;
using Bit.Core.Exceptions;
using Xunit;
namespace Bit.Core.Test.AdminConsole.Utilities;
public class PolicyDataValidatorTests
{
[Fact]
public void ValidateAndSerialize_NullData_ReturnsNull()
{
var result = PolicyDataValidator.ValidateAndSerialize(null, PolicyType.MasterPassword);
Assert.Null(result);
}
[Fact]
public void ValidateAndSerialize_ValidData_ReturnsSerializedJson()
{
var data = new Dictionary<string, object> { { "minLength", 12 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minLength\":12", result);
}
[Fact]
public void ValidateAndSerialize_InvalidDataType_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minLength", "not a number" } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
Assert.Contains("minLength", exception.Message);
}
[Fact]
public void ValidateAndDeserializeMetadata_NullMetadata_ReturnsEmptyMetadataModel()
{
var result = PolicyDataValidator.ValidateAndDeserializeMetadata(null, PolicyType.SingleOrg);
Assert.IsType<EmptyMetadataModel>(result);
}
[Fact]
public void ValidateAndDeserializeMetadata_ValidMetadata_ReturnsModel()
{
var metadata = new Dictionary<string, object> { { "defaultUserCollectionName", "collection name" } };
var result = PolicyDataValidator.ValidateAndDeserializeMetadata(metadata, PolicyType.OrganizationDataOwnership);
Assert.IsType<OrganizationModelOwnershipPolicyModel>(result);
}
}

View File

@@ -1,5 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Clients;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Entities;
@@ -11,12 +12,18 @@ using Invoice = BitPayLight.Models.Invoice.Invoice;
namespace Bit.Core.Test.Billing.Payment.Commands;
using static BitPayConstants;
public class CreateBitPayInvoiceForCreditCommandTests
{
private readonly IBitPayClient _bitPayClient = Substitute.For<IBitPayClient>();
private readonly GlobalSettings _globalSettings = new()
{
BitPay = new GlobalSettings.BitPaySettings { NotificationUrl = "https://example.com/bitpay/notification" }
BitPay = new GlobalSettings.BitPaySettings
{
NotificationUrl = "https://example.com/bitpay/notification",
WebhookKey = "test-webhook-key"
}
};
private const string _redirectUrl = "https://bitwarden.com/redirect";
private readonly CreateBitPayInvoiceForCreditCommand _command;
@@ -37,8 +44,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == user.Email &&
options.Buyer.Name == user.Email &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"userId:{user.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"userId:{user.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
@@ -58,8 +65,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == organization.BillingEmail &&
options.Buyer.Name == organization.Name &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"organizationId:{organization.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"organizationId:{organization.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
@@ -79,8 +86,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == provider.BillingEmail &&
options.Buyer.Name == provider.Name &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"providerId:{provider.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"providerId:{provider.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });

View File

@@ -0,0 +1,112 @@
using System.Text.Json;
using Bit.Core.Billing.Payment.Models;
using Xunit;
namespace Bit.Core.Test.Billing.Payment.Models;
public class PaymentMethodTests
{
[Theory]
[InlineData("{\"cardNumber\":\"1234\"}")]
[InlineData("{\"type\":\"unknown_type\",\"data\":\"value\"}")]
[InlineData("{\"type\":\"invalid\",\"token\":\"test-token\"}")]
[InlineData("{\"type\":\"invalid\"}")]
public void Read_ShouldThrowJsonException_OnInvalidOrMissingType(string json)
{
// Arrange
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act & Assert
Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<PaymentMethod>(json, options));
}
[Theory]
[InlineData("{\"type\":\"card\"}")]
[InlineData("{\"type\":\"card\",\"token\":\"\"}")]
[InlineData("{\"type\":\"card\",\"token\":null}")]
public void Read_ShouldThrowJsonException_OnInvalidTokenizedPaymentMethodToken(string json)
{
// Arrange
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act & Assert
Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<PaymentMethod>(json, options));
}
// Tokenized payment method deserialization
[Theory]
[InlineData("bankAccount", TokenizablePaymentMethodType.BankAccount)]
[InlineData("card", TokenizablePaymentMethodType.Card)]
[InlineData("payPal", TokenizablePaymentMethodType.PayPal)]
public void Read_ShouldDeserializeTokenizedPaymentMethods(string typeString, TokenizablePaymentMethodType expectedType)
{
// Arrange
var json = $"{{\"type\":\"{typeString}\",\"token\":\"test-token\"}}";
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act
var result = JsonSerializer.Deserialize<PaymentMethod>(json, options);
// Assert
Assert.True(result.IsTokenized);
Assert.Equal(expectedType, result.AsT0.Type);
Assert.Equal("test-token", result.AsT0.Token);
}
// Non-tokenized payment method deserialization
[Theory]
[InlineData("accountcredit", NonTokenizablePaymentMethodType.AccountCredit)]
public void Read_ShouldDeserializeNonTokenizedPaymentMethods(string typeString, NonTokenizablePaymentMethodType expectedType)
{
// Arrange
var json = $"{{\"type\":\"{typeString}\"}}";
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act
var result = JsonSerializer.Deserialize<PaymentMethod>(json, options);
// Assert
Assert.True(result.IsNonTokenized);
Assert.Equal(expectedType, result.AsT1.Type);
}
// Tokenized payment method serialization
[Theory]
[InlineData(TokenizablePaymentMethodType.BankAccount, "bankaccount")]
[InlineData(TokenizablePaymentMethodType.Card, "card")]
[InlineData(TokenizablePaymentMethodType.PayPal, "paypal")]
public void Write_ShouldSerializeTokenizedPaymentMethods(TokenizablePaymentMethodType type, string expectedTypeString)
{
// Arrange
var paymentMethod = new PaymentMethod(new TokenizedPaymentMethod
{
Type = type,
Token = "test-token"
});
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act
var json = JsonSerializer.Serialize(paymentMethod, options);
// Assert
Assert.Contains($"\"type\":\"{expectedTypeString}\"", json);
Assert.Contains("\"token\":\"test-token\"", json);
}
// Non-tokenized payment method serialization
[Theory]
[InlineData(NonTokenizablePaymentMethodType.AccountCredit, "accountcredit")]
public void Write_ShouldSerializeNonTokenizedPaymentMethods(NonTokenizablePaymentMethodType type, string expectedTypeString)
{
// Arrange
var paymentMethod = new PaymentMethod(new NonTokenizedPaymentMethod { Type = type });
var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } };
// Act
var json = JsonSerializer.Serialize(paymentMethod, options);
// Assert
Assert.Contains($"\"type\":\"{expectedTypeString}\"", json);
Assert.DoesNotContain("token", json);
}
}

View File

@@ -1,7 +1,12 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
@@ -14,6 +19,8 @@ using NSubstitute;
using Stripe;
using Xunit;
using Address = Stripe.Address;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
using StripeCustomer = Stripe.Customer;
using StripeSubscription = Stripe.Subscription;
@@ -28,6 +35,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPushNotificationService _pushNotificationService = Substitute.For<IPushNotificationService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IHasPaymentMethodQuery _hasPaymentMethodQuery = Substitute.For<IHasPaymentMethodQuery>();
private readonly IUpdatePaymentMethodCommand _updatePaymentMethodCommand = Substitute.For<IUpdatePaymentMethodCommand>();
private readonly CreatePremiumCloudHostedSubscriptionCommand _command;
public CreatePremiumCloudHostedSubscriptionCommandTests()
@@ -36,6 +46,17 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
baseServiceUri.CloudRegion.Returns("US");
_globalSettings.BaseServiceUri.Returns(baseServiceUri);
// Setup default premium plan with standard pricing
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new PremiumPurchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually },
Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal }
};
_pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan);
_command = new CreatePremiumCloudHostedSubscriptionCommand(
_braintreeGateway,
_globalSettings,
@@ -44,7 +65,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
_subscriberService,
_userService,
_pushNotificationService,
Substitute.For<ILogger<CreatePremiumCloudHostedSubscriptionCommand>>());
Substitute.For<ILogger<CreatePremiumCloudHostedSubscriptionCommand>>(),
_pricingClient,
_hasPaymentMethodQuery,
_updatePaymentMethodCommand);
}
[Theory, BitAutoData]
@@ -296,7 +320,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
}
[Theory, BitAutoData]
public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer(
public async Task Run_UserHasExistingGatewayCustomerIdAndPaymentMethod_UsesExistingCustomer(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
@@ -329,6 +353,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockInvoice = Substitute.For<Invoice>();
// Mock that the user has a payment method (this is the key difference from the credit purchase case)
_hasPaymentMethodQuery.Run(Arg.Any<User>()).Returns(true);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
@@ -340,6 +366,75 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any<User>(), Arg.Any<TokenizedPaymentMethod>(), Arg.Any<BillingAddress>());
}
[Theory, BitAutoData]
public async Task Run_UserPreviouslyPurchasedCreditWithoutPaymentMethod_UpdatesPaymentMethodAndCreatesSubscription(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = "existing_customer_123"; // Customer exists from previous credit purchase
paymentMethod.Type = TokenizablePaymentMethodType.Card;
paymentMethod.Token = "card_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "existing_customer_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem
{
CurrentPeriodEnd = DateTime.UtcNow.AddDays(30)
}
]
};
var mockInvoice = Substitute.For<Invoice>();
MaskedPaymentMethod mockMaskedPaymentMethod = new MaskedCard
{
Brand = "visa",
Last4 = "1234",
Expiration = "12/2025"
};
// Mock that the user does NOT have a payment method (simulating credit purchase scenario)
_hasPaymentMethodQuery.Run(Arg.Any<User>()).Returns(false);
_updatePaymentMethodCommand.Run(Arg.Any<User>(), Arg.Any<TokenizedPaymentMethod>(), Arg.Any<BillingAddress>())
.Returns(mockMaskedPaymentMethod);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
// Verify that update payment method was called (new behavior for credit purchase case)
await _updatePaymentMethodCommand.Received(1).Run(user, paymentMethod, billingAddress);
// Verify GetCustomerOrThrow was called after updating payment method
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
// Verify no new customer was created
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
// Verify subscription was created
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
// Verify user was updated correctly
Assert.True(user.Premium);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]
@@ -550,4 +645,79 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var unhandled = result.AsT3;
Assert.Equal("Something went wrong with your request. Please contact support for assistance.", unhandled.Response);
}
[Theory, BitAutoData]
public async Task Run_AccountCredit_WithExistingCustomer_Success(
User user,
NonTokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = "existing_customer_123";
paymentMethod.Type = NonTokenizablePaymentMethodType.AccountCredit;
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "existing_customer_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem
{
CurrentPeriodEnd = DateTime.UtcNow.AddDays(30)
}
]
};
var mockInvoice = Substitute.For<Invoice>();
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
Assert.True(user.Premium);
Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate);
}
[Theory, BitAutoData]
public async Task Run_NonTokenizedPaymentWithoutExistingCustomer_ThrowsBillingException(
User user,
NonTokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
// No existing gateway customer ID
user.GatewayCustomerId = null;
paymentMethod.Type = NonTokenizablePaymentMethodType.AccountCredit;
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
//Assert
Assert.True(result.IsT3); // Assuming T3 is the Unhandled result
Assert.IsType<BillingException>(result.AsT3.Exception);
// Verify no customer was created or subscription attempted
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.DidNotReceive().SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _userService.DidNotReceive().SaveUserAsync(Arg.Any<User>());
}
}

View File

@@ -1,23 +1,38 @@
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using static Bit.Core.Billing.Constants.StripeConstants;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class PreviewPremiumTaxCommandTests
{
private readonly ILogger<PreviewPremiumTaxCommand> _logger = Substitute.For<ILogger<PreviewPremiumTaxCommand>>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly PreviewPremiumTaxCommand _command;
public PreviewPremiumTaxCommandTests()
{
_command = new PreviewPremiumTaxCommand(_logger, _stripeAdapter);
// Setup default premium plan with standard pricing
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
LegacyYear = null,
Seat = new PremiumPurchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually },
Storage = new PremiumPurchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal }
};
_pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan);
_command = new PreviewPremiumTaxCommand(_logger, _pricingClient, _stripeAdapter);
}
[Fact]

View File

@@ -28,6 +28,9 @@
<None Remove="Utilities\data\embeddedResource.txt" />
</ItemGroup>
<ItemGroup>
<!-- Email templates uses .hbs extension, they must be included for emails to work -->
<EmbeddedResource Include="**\*.hbs" />
<EmbeddedResource Include="Utilities\data\embeddedResource.txt" />
</ItemGroup>
</Project>

View File

@@ -278,21 +278,27 @@ public class UpdateSecretsManagerSubscriptionCommandTests
SutProvider<UpdateSecretsManagerSubscriptionCommand> sutProvider)
{
// Arrange
const int seatCount = 10;
var existingSeatCount = 9;
// Make sure Password Manager seats is greater or equal to Secrets Manager seats
organization.Seats = seatCount;
const int initialSeatCount = 9;
const int maxSeatCount = 20;
// This represents the total number of users allowed in the organization.
organization.Seats = maxSeatCount;
// This represents the number of Secrets Manager users allowed in the organization.
organization.SmSeats = initialSeatCount;
// This represents the upper limit of Secrets Manager seats that can be automatically scaled.
organization.MaxAutoscaleSmSeats = maxSeatCount;
organization.PlanType = PlanType.EnterpriseAnnually;
var plan = StaticStore.GetPlan(organization.PlanType);
var update = new SecretsManagerSubscriptionUpdate(organization, plan, false)
{
SmSeats = seatCount,
MaxAutoscaleSmSeats = seatCount
SmSeats = 8,
MaxAutoscaleSmSeats = maxSeatCount
};
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id)
.Returns(existingSeatCount);
.Returns(5);
// Act
await sutProvider.Sut.UpdateSubscriptionAsync(update);
@@ -316,21 +322,29 @@ public class UpdateSecretsManagerSubscriptionCommandTests
SutProvider<UpdateSecretsManagerSubscriptionCommand> sutProvider)
{
// Arrange
const int seatCount = 10;
const int existingSeatCount = 10;
var ownerDetailsList = new List<OrganizationUserUserDetails> { new() { Email = "owner@example.com" } };
const int initialSeatCount = 5;
const int maxSeatCount = 10;
// The amount of seats for users in an organization
// This represents the total number of users allowed in the organization.
organization.Seats = maxSeatCount;
// This represents the number of Secrets Manager users allowed in the organization.
organization.SmSeats = initialSeatCount;
// This represents the upper limit of Secrets Manager seats that can be automatically scaled.
organization.MaxAutoscaleSmSeats = maxSeatCount;
var ownerDetailsList = new List<OrganizationUserUserDetails> { new() { Email = "owner@example.com" } };
organization.PlanType = PlanType.EnterpriseAnnually;
var plan = StaticStore.GetPlan(organization.PlanType);
var update = new SecretsManagerSubscriptionUpdate(organization, plan, false)
{
SmSeats = seatCount,
MaxAutoscaleSmSeats = seatCount
SmSeats = maxSeatCount,
MaxAutoscaleSmSeats = maxSeatCount
};
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id)
.Returns(existingSeatCount);
.Returns(maxSeatCount);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByMinimumRoleAsync(organization.Id, OrganizationUserType.Owner)
.Returns(ownerDetailsList);
@@ -340,15 +354,14 @@ public class UpdateSecretsManagerSubscriptionCommandTests
// Assert
// Currently being called once each for different validation methods
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(2)
.Received(1)
.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id);
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendSecretsManagerMaxSeatLimitReachedEmailAsync(Arg.Is(organization),
Arg.Is(seatCount),
Arg.Is(maxSeatCount),
Arg.Is<IEnumerable<string>>(emails => emails.Contains(ownerDetailsList[0].Email)));
}

View File

@@ -0,0 +1,20 @@
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Test.Platform.Mailer.TestMail;
using Xunit;
namespace Bit.Core.Test.Platform.Mailer;
public class HandlebarMailRendererTests
{
[Fact]
public async Task RenderAsync_ReturnsExpectedHtmlAndTxt()
{
var renderer = new HandlebarMailRenderer();
var view = new TestMailView { Name = "John Smith" };
var (html, txt) = await renderer.RenderAsync(view);
Assert.Equal("Hello <b>John Smith</b>", html.Trim());
Assert.Equal("Hello John Smith", txt.Trim());
}
}

View File

@@ -0,0 +1,36 @@
using Bit.Core.Models.Mail;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Test.Platform.Mailer.TestMail;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Platform.Mailer;
public class MailerTest
{
[Fact]
public async Task SendEmailAsync()
{
var deliveryService = Substitute.For<IMailDeliveryService>();
var mailer = new Core.Platform.Mail.Mailer.Mailer(new HandlebarMailRenderer(), deliveryService);
var mail = new TestMail.TestMail()
{
ToEmails = ["test@bw.com"],
View = new TestMailView() { Name = "John Smith" }
};
MailMessage? sentMessage = null;
await deliveryService.SendEmailAsync(Arg.Do<MailMessage>(message =>
sentMessage = message
));
await mailer.SendEmail(mail);
Assert.NotNull(sentMessage);
Assert.Contains("test@bw.com", sentMessage.ToEmails);
Assert.Equal("Test Email", sentMessage.Subject);
Assert.Equivalent("Hello John Smith", sentMessage.TextContent.Trim());
Assert.Equivalent("Hello <b>John Smith</b>", sentMessage.HtmlContent.Trim());
}
}

View File

@@ -0,0 +1,13 @@
using Bit.Core.Platform.Mail.Mailer;
namespace Bit.Core.Test.Platform.Mailer.TestMail;
public class TestMailView : BaseMailView
{
public required string Name { get; init; }
}
public class TestMail : BaseMail<TestMailView>
{
public override string Subject { get; } = "Test Email";
}

View File

@@ -0,0 +1 @@
Hello <b>{{ Name }}</b>

View File

@@ -0,0 +1 @@
Hello {{ Name }}

View File

@@ -1,7 +1,7 @@
using Amazon.SimpleEmail;
using Amazon.SimpleEmail.Model;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;

View File

@@ -6,7 +6,10 @@ using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Business;
using Bit.Core.Entities;
using Bit.Core.Models.Mail;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Enqueuing;
using Bit.Core.Services;
using Bit.Core.Services.Mail;
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;

View File

@@ -1,4 +1,4 @@
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using NSubstitute;

View File

@@ -1,5 +1,5 @@
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;

View File

@@ -113,6 +113,242 @@ public class CipherServiceTests
await sutProvider.GetDependency<ICipherRepository>().Received(1).ReplaceAsync(cipherDetails);
}
[Theory, BitAutoData]
public async Task CreateAttachmentAsync_WrongRevisionDate_Throws(SutProvider<CipherService> sutProvider, Cipher cipher, Guid savingUserId)
{
var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1);
var stream = new MemoryStream();
var fileName = "test.txt";
var key = "test-key";
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, lastKnownRevisionDate));
Assert.Contains("out of date", exception.Message);
}
[Theory]
[BitAutoData("")]
[BitAutoData("Correct Time")]
public async Task CreateAttachmentAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString,
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate;
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Mock user storage and premium access
var user = new User { Id = savingUserId, MaxStorageGb = 1 };
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, Arg.Any<CipherAttachment.MetaData>(), Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.ReplaceAsync(Arg.Any<CipherDetails>())
.Returns(Task.CompletedTask);
await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, lastKnownRevisionDate);
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(Arg.Any<Stream>(), cipher, Arg.Any<CipherAttachment.MetaData>());
}
[Theory, BitAutoData]
public async Task CreateAttachmentForDelayedUploadAsync_WrongRevisionDate_Throws(SutProvider<CipherService> sutProvider, Cipher cipher, Guid savingUserId)
{
var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1);
var key = "test-key";
var fileName = "test.txt";
var fileSize = 100L;
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateAttachmentForDelayedUploadAsync(cipher, key, fileName, fileSize, false, savingUserId, lastKnownRevisionDate));
Assert.Contains("out of date", exception.Message);
}
[Theory]
[BitAutoData("")]
[BitAutoData("Correct Time")]
public async Task CreateAttachmentForDelayedUploadAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString,
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid savingUserId)
{
var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate;
var key = "test-key";
var fileName = "test.txt";
var fileSize = 100L;
// Setup cipher with user ownership
cipher.UserId = savingUserId;
cipher.OrganizationId = null;
// Mock user storage and premium access
var user = new User { Id = savingUserId, MaxStorageGb = 1 };
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(savingUserId)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<IAttachmentStorageService>()
.GetAttachmentUploadUrlAsync(cipher, Arg.Any<CipherAttachment.MetaData>())
.Returns("https://example.com/upload");
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
var result = await sutProvider.Sut.CreateAttachmentForDelayedUploadAsync(cipher, key, fileName, fileSize, false, savingUserId, lastKnownRevisionDate);
Assert.NotNull(result.attachmentId);
Assert.NotNull(result.uploadUrl);
}
[Theory, BitAutoData]
public async Task UploadFileForExistingAttachmentAsync_WrongRevisionDate_Throws(SutProvider<CipherService> sutProvider,
Cipher cipher)
{
var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1);
var stream = new MemoryStream();
var attachment = new CipherAttachment.MetaData
{
AttachmentId = "test-attachment-id",
Size = 100,
FileName = "test.txt",
Key = "test-key"
};
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UploadFileForExistingAttachmentAsync(stream, cipher, attachment, lastKnownRevisionDate));
Assert.Contains("out of date", exception.Message);
}
[Theory]
[BitAutoData("")]
[BitAutoData("Correct Time")]
public async Task UploadFileForExistingAttachmentAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString,
SutProvider<CipherService> sutProvider, CipherDetails cipher)
{
var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate;
var stream = new MemoryStream(new byte[100]);
var attachmentId = "test-attachment-id";
var attachment = new CipherAttachment.MetaData
{
AttachmentId = attachmentId,
Size = 100,
FileName = "test.txt",
Key = "test-key"
};
// Set the attachment on the cipher so ValidateCipherAttachmentFile can find it
cipher.SetAttachments(new Dictionary<string, CipherAttachment.MetaData>
{
[attachmentId] = attachment
});
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadNewAttachmentAsync(stream, cipher, attachment)
.Returns(Task.CompletedTask);
sutProvider.GetDependency<IAttachmentStorageService>()
.ValidateFileAsync(cipher, attachment, Arg.Any<long>())
.Returns((true, 100L));
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
await sutProvider.Sut.UploadFileForExistingAttachmentAsync(stream, cipher, attachment, lastKnownRevisionDate);
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadNewAttachmentAsync(stream, cipher, attachment);
}
[Theory, BitAutoData]
public async Task CreateAttachmentShareAsync_WrongRevisionDate_Throws(SutProvider<CipherService> sutProvider,
Cipher cipher, Guid organizationId)
{
var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1);
var stream = new MemoryStream();
var fileName = "test.txt";
var key = "test-key";
var attachmentId = "attachment-id";
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateAttachmentShareAsync(cipher, stream, fileName, key, 100, attachmentId, organizationId, lastKnownRevisionDate));
Assert.Contains("out of date", exception.Message);
}
[Theory]
[BitAutoData("")]
[BitAutoData("Correct Time")]
public async Task CreateAttachmentShareAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString,
SutProvider<CipherService> sutProvider, CipherDetails cipher, Guid organizationId)
{
var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate;
var stream = new MemoryStream(new byte[100]);
var fileName = "test.txt";
var key = "test-key";
var attachmentId = "attachment-id";
// Setup cipher with existing attachment (no TempMetadata)
cipher.OrganizationId = null;
cipher.SetAttachments(new Dictionary<string, CipherAttachment.MetaData>
{
[attachmentId] = new CipherAttachment.MetaData
{
AttachmentId = attachmentId,
Size = 100,
FileName = "existing.txt",
Key = "existing-key"
}
});
// Mock organization
var organization = new Organization
{
Id = organizationId,
MaxStorageGb = 1
};
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organizationId)
.Returns(organization);
sutProvider.GetDependency<IAttachmentStorageService>()
.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, Arg.Any<CipherAttachment.MetaData>())
.Returns(Task.CompletedTask);
sutProvider.GetDependency<ICipherRepository>()
.UpdateAttachmentAsync(Arg.Any<CipherAttachment>())
.Returns(Task.CompletedTask);
await sutProvider.Sut.CreateAttachmentShareAsync(cipher, stream, fileName, key, 100, attachmentId, organizationId, lastKnownRevisionDate);
await sutProvider.GetDependency<IAttachmentStorageService>().Received(1)
.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, Arg.Any<CipherAttachment.MetaData>());
}
[Theory]
[BitAutoData]
public async Task SaveDetailsAsync_PersonalVault_WithOrganizationDataOwnershipPolicyEnabled_Throws(

View File

@@ -1,21 +1,63 @@
using System.Net.Http.Json;
using System.Net;
using System.Net.Http.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums;
using Bit.Core.Vault.Repositories;
using Bit.Events.Models;
namespace Bit.Events.IntegrationTest.Controllers;
public class CollectControllerTests
public class CollectControllerTests : IAsyncLifetime
{
// This is a very simple test, and should be updated to assert more things, but for now
// it ensures that the events startup doesn't throw any errors with fairly basic configuration.
[Fact]
public async Task Post_Works()
{
var eventsApplicationFactory = new EventsApplicationFactory();
var (accessToken, _) = await eventsApplicationFactory.LoginWithNewAccount();
var client = eventsApplicationFactory.CreateAuthedClient(accessToken);
private EventsApplicationFactory _factory = null!;
private HttpClient _client = null!;
private string _ownerEmail = null!;
private Guid _ownerId;
var response = await client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
public async Task InitializeAsync()
{
_factory = new EventsApplicationFactory();
_ownerEmail = $"integration-test+{Guid.NewGuid()}@bitwarden.com";
var (accessToken, _) = await _factory.LoginWithNewAccount(_ownerEmail);
_client = _factory.CreateAuthedClient(accessToken);
// Get the user ID
var userRepository = _factory.GetService<IUserRepository>();
var user = await userRepository.GetByEmailAsync(_ownerEmail);
_ownerId = user!.Id;
}
public Task DisposeAsync()
{
_client?.Dispose();
_factory?.Dispose();
return Task.CompletedTask;
}
[Fact]
public async Task Post_NullModel_ReturnsBadRequest()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>?>("collect", null);
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Post_EmptyModel_ReturnsBadRequest()
{
var response = await _client.PostAsJsonAsync("collect", Array.Empty<EventModel>());
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Post_UserClientExportedVault_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
@@ -26,4 +68,425 @@ public class CollectControllerTests
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientAutofilled_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedPassword_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedHiddenField_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedHiddenField,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedCardCode_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedCardCode,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledCardNumberVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledCardNumberVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledCardCodeVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledCardCodeVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledHiddenFieldVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledHiddenFieldVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledPasswordVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledPasswordVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientViewed_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEvent_WithoutCipherId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEvent_WithInvalidCipherId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = Guid.NewGuid(),
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithValidOrganization_Success()
{
var organization = await CreateOrganizationAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = organization.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithoutOrganizationId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithInvalidOrganizationId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = Guid.NewGuid(),
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_MultipleEvents_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var organization = await CreateOrganizationAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = organization.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEventsBatch_MoreThan50Items_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
// Create 60 cipher events to test batching logic (should be processed in 2 batches of 50)
var events = Enumerable.Range(0, 60)
.Select(_ => new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
})
.ToList();
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect", events);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_UnsupportedEventType_Success()
{
// Testing with an event type not explicitly handled in the switch statement
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_LoggedIn,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_MixedValidAndInvalidEvents_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = Guid.NewGuid(), // Invalid cipher ID
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id, // Valid cipher ID
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherCaching_MultipleEventsForSameCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
// Multiple events for the same cipher should use caching
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
private async Task<Cipher> CreateCipherForUserAsync(Guid userId)
{
var cipherRepository = _factory.GetService<ICipherRepository>();
var cipher = new Cipher
{
Type = CipherType.Login,
UserId = userId,
Data = "{\"name\":\"Test Cipher\"}",
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await cipherRepository.CreateAsync(cipher);
return cipher;
}
private async Task<Organization> CreateOrganizationAsync(Guid ownerId)
{
var organizationRepository = _factory.GetService<IOrganizationRepository>();
var organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
var organization = new Organization
{
Name = "Test Organization",
BillingEmail = _ownerEmail,
Plan = "Free",
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await organizationRepository.CreateAsync(organization);
// Add the user as an owner of the organization
var organizationUser = new OrganizationUser
{
OrganizationId = organization.Id,
UserId = ownerId,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
};
await organizationUserRepository.CreateAsync(organizationUser);
return organization;
}
}

View File

@@ -0,0 +1,715 @@
using AutoFixture.Xunit2;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories;
using Bit.Events.Controllers;
using Bit.Events.Models;
using Microsoft.AspNetCore.Mvc;
using NSubstitute;
namespace Events.Test.Controllers;
public class CollectControllerTests
{
private readonly CollectController _sut;
private readonly ICurrentContext _currentContext;
private readonly IEventService _eventService;
private readonly ICipherRepository _cipherRepository;
private readonly IOrganizationRepository _organizationRepository;
public CollectControllerTests()
{
_currentContext = Substitute.For<ICurrentContext>();
_eventService = Substitute.For<IEventService>();
_cipherRepository = Substitute.For<ICipherRepository>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_sut = new CollectController(
_currentContext,
_eventService,
_cipherRepository,
_organizationRepository
);
}
[Fact]
public async Task Post_NullModel_ReturnsBadRequest()
{
var result = await _sut.Post(null);
Assert.IsType<BadRequestResult>(result);
}
[Fact]
public async Task Post_EmptyModel_ReturnsBadRequest()
{
var result = await _sut.Post(new List<EventModel>());
Assert.IsType<BadRequestResult>(result);
}
[Theory]
[AutoData]
public async Task Post_UserClientExportedVault_LogsUserEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogUserEventAsync(userId, EventType.User_ClientExportedVault, eventDate);
}
[Theory]
[AutoData]
public async Task Post_CipherAutofilled_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.Count() == 1 &&
tuples.First().Item1 == cipherDetails &&
tuples.First().Item2 == EventType.Cipher_ClientAutofilled &&
tuples.First().Item3 == eventDate
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedPassword_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedPassword
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedHiddenField_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedHiddenField,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedHiddenField
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedCardCode_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedCardCode,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedCardCode
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledCardNumberVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledCardNumberVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledCardNumberVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledCardCodeVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledCardCodeVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledCardCodeVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledHiddenFieldVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledHiddenFieldVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledHiddenFieldVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledPasswordVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledPasswordVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledPasswordVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientViewed_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientViewed
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithoutCipherId_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default, default);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_WithoutOrgId_SkipsEvent(Guid userId, Guid cipherId)
{
_currentContext.UserId.Returns(userId);
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(cipherId);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_WithOrgId_ChecksOrgCipher(
Guid userId, Guid cipherId, Guid orgId, Cipher cipher, CurrentContextOrganization org)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
_currentContext.GetOrganization(orgId).Returns(org);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.Received(1).GetByIdAsync(cipherId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item1 == cipher
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_OrgCipherNotFound_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId)
{
_currentContext.UserId.Returns(userId);
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns((CipherDetails?)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.Received(1).GetByIdAsync(cipherId);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_CipherDoesNotBelongToOrg_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId, Guid differentOrgId, Cipher cipher)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = differentOrgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_OrgNotFound_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId, Cipher cipher)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
_currentContext.GetOrganization(orgId).Returns((CurrentContextOrganization)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_MultipleCipherEvents_WithSameCipherId_UsesCachedCipher(
Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipherId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 2)
);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithValidOrg_LogsOrgEvent(
Guid userId, Guid orgId, Organization organization)
{
_currentContext.UserId.Returns(userId);
organization.Id = orgId;
_organizationRepository.GetByIdAsync(orgId).Returns(organization);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.Received(1).GetByIdAsync(orgId);
await _eventService.Received(1).LogOrganizationEventAsync(organization, EventType.Organization_ClientExportedVault, eventDate);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithoutOrgId_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
await _eventService.DidNotReceiveWithAnyArgs().LogOrganizationEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithNullOrg_SkipsEvent(Guid userId, Guid orgId)
{
_currentContext.UserId.Returns(userId);
_organizationRepository.GetByIdAsync(orgId).Returns((Organization)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.Received(1).GetByIdAsync(orgId);
await _eventService.DidNotReceiveWithAnyArgs().LogOrganizationEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_UnsupportedEventType_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_LoggedIn,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogUserEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_MixedEventTypes_ProcessesAllEvents(
Guid userId, Guid cipherId, Guid orgId, CipherDetails cipherDetails, Organization organization)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
organization.Id = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
_organizationRepository.GetByIdAsync(orgId).Returns(organization);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogUserEventAsync(userId, EventType.User_ClientExportedVault, Arg.Any<DateTime?>());
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 1)
);
await _eventService.Received(1).LogOrganizationEventAsync(organization, EventType.Organization_ClientExportedVault, Arg.Any<DateTime?>());
}
[Theory]
[AutoData]
public async Task Post_MoreThan50CipherEvents_LogsInBatches(Guid userId, List<CipherDetails> ciphers)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>();
for (int i = 0; i < 100; i++)
{
var cipher = ciphers[i % ciphers.Count];
_cipherRepository.GetByIdAsync(cipher.Id, userId).Returns(cipher);
events.Add(new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow
});
}
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(2).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 50)
);
}
[Theory]
[AutoData]
public async Task Post_Exactly50CipherEvents_LogsInSingleBatch(Guid userId, List<CipherDetails> ciphers)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>();
for (int i = 0; i < 50; i++)
{
var cipher = ciphers[i % ciphers.Count];
_cipherRepository.GetByIdAsync(cipher.Id, userId).Returns(cipher);
events.Add(new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow
});
}
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 50)
);
}
}

View File

@@ -9,14 +9,18 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="NSubstitute" Version="$(NSubstituteVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio"
Version="$(XUnitRunnerVisualStudioVersion)">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
<PackageReference Include="AutoFixture.Xunit2" Version="$(AutoFixtureXUnit2Version)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Events\Events.csproj" />
<ProjectReference Include="..\..\src\Core\Core.csproj" />
<ProjectReference Include="..\Common\Common.csproj" />
</ItemGroup>
</Project>

View File

@@ -1,10 +0,0 @@
namespace Events.Test;
// Delete this file once you have real tests
public class PlaceholderUnitTest
{
[Fact]
public void Test1()
{
}
}

View File

@@ -1,6 +1,7 @@
using System.Reflection;
using AutoFixture;
using AutoFixture.Xunit2;
using Bit.Identity.IdentityServer;
using Duende.IdentityServer.Validation;
namespace Bit.Identity.Test.AutoFixture;
@@ -8,7 +9,8 @@ namespace Bit.Identity.Test.AutoFixture;
internal class ValidatedTokenRequestCustomization : ICustomization
{
public ValidatedTokenRequestCustomization()
{ }
{
}
public void Customize(IFixture fixture)
{
@@ -22,10 +24,45 @@ internal class ValidatedTokenRequestCustomization : ICustomization
public class ValidatedTokenRequestAttribute : CustomizeAttribute
{
public ValidatedTokenRequestAttribute()
{ }
{
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new ValidatedTokenRequestCustomization();
}
}
internal class CustomValidatorRequestContextCustomization : ICustomization
{
public CustomValidatorRequestContextCustomization()
{
}
/// <summary>
/// Specific context members like <see cref="CustomValidatorRequestContext.RememberMeRequested" />,
/// <see cref="CustomValidatorRequestContext.TwoFactorRecoveryRequested"/>, and
/// <see cref="CustomValidatorRequestContext.SsoRequired" /> should initialize false,
/// and are made truthy in context upon evaluation of a request. Do not allow AutoFixture to eagerly make these
/// truthy; that is the responsibility of the <see cref="Bit.Identity.IdentityServer.RequestValidators.BaseRequestValidator{T}" />
/// </summary>
public void Customize(IFixture fixture)
{
fixture.Customize<CustomValidatorRequestContext>(composer => composer
.With(o => o.RememberMeRequested, false)
.With(o => o.TwoFactorRecoveryRequested, false)
.With(o => o.SsoRequired, false));
}
}
public class CustomValidatorRequestContextAttribute : CustomizeAttribute
{
public CustomValidatorRequestContextAttribute()
{
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new CustomValidatorRequestContextCustomization();
}
}

View File

@@ -100,19 +100,30 @@ public class BaseRequestValidatorTests
_userAccountKeysQuery);
}
private void SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(bool recoveryCodeSupportEnabled)
{
_featureService
.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)
.Returns(recoveryCodeSupportEnabled);
}
/* Logic path
* ValidateAsync -> UpdateFailedAuthDetailsAsync -> _mailService.SendFailedLoginAttemptsEmailAsync
* |-> BuildErrorResultAsync -> _eventService.LogUserEventAsync
* (self hosted) |-> _logger.LogWarning()
* |-> SetErrorResult
*/
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ContextNotValid_SelfHosted_ShouldBuildErrorResult_ShouldLogWarning(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_globalSettings.SelfHosted = true;
_sut.isValid = false;
@@ -122,18 +133,23 @@ public class BaseRequestValidatorTests
// Assert
var logs = _logger.Collector.GetSnapshot(true);
Assert.Contains(logs, l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
Assert.Contains(logs,
l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_DeviceNotValidated_ShouldLogError(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -141,14 +157,15 @@ public class BaseRequestValidatorTests
// 2 -> will result to false with no extra configuration
// 3 -> set two factor to be false
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to fail
requestContext.KnownDevice = false;
tokenRequest.GrantType = "password";
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
@@ -163,13 +180,17 @@ public class BaseRequestValidatorTests
.LogUserEventAsync(context.CustomValidatorRequestContext.User.Id, EventType.User_FailedLogIn);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_DeviceValidated_ShouldSucceed(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -177,12 +198,13 @@ public class BaseRequestValidatorTests
// 2 -> will result to false with no extra configuration
// 3 -> set two factor to be false
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to pass
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
@@ -202,13 +224,17 @@ public class BaseRequestValidatorTests
Assert.False(context.GrantResult.IsError);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ValidatedAuthRequest_ConsumedOnSuccess(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -235,7 +261,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to pass
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// 5 -> not legacy user
@@ -260,13 +287,17 @@ public class BaseRequestValidatorTests
ar.AuthenticationDate.HasValue));
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ValidatedAuthRequest_NotConsumed_When2faRequired(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -302,13 +333,17 @@ public class BaseRequestValidatorTests
await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any<AuthRequest>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_TwoFactorTokenInvalid_ShouldSendFailedTwoFactorEmail(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -345,13 +380,17 @@ public class BaseRequestValidatorTests
Arg.Any<string>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_TwoFactorRememberTokenExpired_ShouldNotSendFailedTwoFactorEmail(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -391,28 +430,34 @@ public class BaseRequestValidatorTests
// Assert
// Verify that the failed 2FA email was NOT sent for remember token expiration
await _mailService.DidNotReceive()
.SendFailedTwoFactorAttemptEmailAsync(Arg.Any<string>(), Arg.Any<TwoFactorProviderType>(), Arg.Any<DateTime>(), Arg.Any<string>());
.SendFailedTwoFactorAttemptEmailAsync(Arg.Any<string>(), Arg.Any<TwoFactorProviderType>(),
Arg.Any<DateTime>(), Arg.Any<string>());
}
// Test grantTypes that require SSO when a user is in an organization that requires it
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_OrgSsoRequiredTrue_ShouldSetSsoResult(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
context.ValidatedTokenRequest.GrantType = grantType;
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -425,16 +470,21 @@ public class BaseRequestValidatorTests
// Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredTrue_ShouldSetSsoResult(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -449,23 +499,28 @@ public class BaseRequestValidatorTests
// Assert
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
}
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -500,24 +555,29 @@ public class BaseRequestValidatorTests
// Test grantTypes where SSO would be required but the user is not in an
// organization that requires it
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_OrgSsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
context.ValidatedTokenRequest.GrantType = grantType;
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -540,20 +600,23 @@ public class BaseRequestValidatorTests
await _userRepository.Received(1).ReplaceAsync(Arg.Any<User>());
Assert.False(context.GrantResult.IsError);
}
// Test the grantTypes where SSO is in progress or not relevant
[Theory]
[BitAutoData("authorization_code")]
[BitAutoData("client_credentials")]
[BitAutoData("authorization_code", true)]
[BitAutoData("authorization_code", false)]
[BitAutoData("client_credentials", true)]
[BitAutoData("client_credentials", false)]
public async Task ValidateAsync_GrantTypes_SsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -577,7 +640,7 @@ public class BaseRequestValidatorTests
// Assert
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
await _eventService.Received(1).LogUserEventAsync(
context.CustomValidatorRequestContext.User.Id, EventType.User_LoggedIn);
await _userRepository.Received(1).ReplaceAsync(Arg.Any<User>());
@@ -588,13 +651,17 @@ public class BaseRequestValidatorTests
/* Logic Path
* ValidateAsync -> UserService.IsLegacyUser -> FailAuthForLegacyUserAsync
*/
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_IsLegacyUser_FailAuthForLegacyUserAsync(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = context.CustomValidatorRequestContext.User;
user.Key = null;
@@ -613,21 +680,27 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var expectedMessage = "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
var expectedMessage =
"Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
Assert.Equal(expectedMessage, errorResponse.Message);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_NoMasterPassword_ShouldSetUserDecryptionOptions(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = false,
@@ -663,19 +736,24 @@ public class BaseRequestValidatorTests
}
[Theory]
[BitAutoData(KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(KdfType.Argon2id, 11, 128, 5)]
[BitAutoData(true, KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(false, KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(true, KdfType.Argon2id, 11, 128, 5)]
[BitAutoData(false, KdfType.Argon2id, 11, 128, 5)]
public async Task ValidateAsync_CustomResponse_MasterPassword_ShouldSetUserDecryptionOptions(
bool featureFlagValue,
KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = true,
@@ -728,13 +806,17 @@ public class BaseRequestValidatorTests
Assert.Equal("test@example.com", userDecryptionOptions.MasterPasswordUnlock.Salt);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_ShouldIncludeAccountKeys(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var mockAccountKeys = new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -747,11 +829,7 @@ public class BaseRequestValidatorTests
"test-wrapped-signing-key",
"test-verifying-key"
),
SecurityStateData = new SecurityStateData
{
SecurityState = "test-security-state",
SecurityVersion = 2
}
SecurityStateData = new SecurityStateData { SecurityState = "test-security-state", SecurityVersion = 2 }
};
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(mockAccountKeys);
@@ -759,7 +837,8 @@ public class BaseRequestValidatorTests
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = true,
@@ -808,13 +887,18 @@ public class BaseRequestValidatorTests
Assert.Equal("test-security-state", accountKeysResponse.SecurityState.SecurityState);
Assert.Equal(2, accountKeysResponse.SecurityState.SecurityVersion);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_AccountKeysQuery_SkippedWhenPrivateKeyIsNull(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
requestContext.User.PrivateKey = null;
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -833,13 +917,18 @@ public class BaseRequestValidatorTests
// Verify that the account keys query wasn't called.
await _userAccountKeysQuery.Received(0).Run(Arg.Any<User>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_AccountKeysQuery_CalledWithCorrectUser(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var expectedUser = requestContext.User;
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
@@ -853,7 +942,8 @@ public class BaseRequestValidatorTests
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions()));
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -874,6 +964,285 @@ public class BaseRequestValidatorTests
await _userAccountKeysQuery.Received(1).Run(Arg.Is<User>(u => u.Id == expectedUser.Id));
}
/// <summary>
/// Tests the core PM-21153 feature: SSO-required users can use recovery codes to disable 2FA,
/// but must then authenticate via SSO with a descriptive message about the recovery.
/// This test validates:
/// 1. Validation order is changed (2FA before SSO) when recovery code is provided
/// 2. Recovery code successfully validates and sets TwoFactorRecoveryRequested flag
/// 3. SSO validation then fails with recovery-specific message
/// 4. User is NOT logged in (must authenticate via IdP)
/// </summary>
[Theory]
[BitAutoData(true)] // Feature flag ON - new behavior
[BitAutoData(false)] // Feature flag OFF - should fail at SSO before 2FA recovery
public async Task ValidateAsync_RecoveryCodeForSsoRequiredUser_BlocksWithDescriptiveMessage(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// Reset state that AutoFixture may have populated
requestContext.TwoFactorRecoveryRequested = false;
requestContext.RememberMeRequested = false;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is required (this user is in an org that requires SSO)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE (this triggers the special validation order)
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-12345";
// 5. Recovery code is valid (UserService.RecoverTwoFactorAsync will be called internally)
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-12345")
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
if (featureFlagEnabled)
{
// NEW BEHAVIOR: Recovery succeeds, then SSO blocks with descriptive message
Assert.Equal(
"Two-factor recovery has been performed. SSO authentication is required.",
errorResponse.Message);
// Verify recovery was marked
Assert.True(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested flag should be set");
}
else
{
// LEGACY BEHAVIOR: SSO blocks BEFORE recovery can happen
Assert.Equal(
"SSO authentication is required.",
errorResponse.Message);
// Recovery never happened because SSO checked first
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false (SSO blocked first)");
}
// In both cases: User is NOT logged in
await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn);
}
/// <summary>
/// Tests that validation order changes when a recovery code is PROVIDED (even if invalid).
/// This ensures the RecoveryCodeRequestForSsoRequiredUserScenario() logic is based on
/// request structure, not validation outcome. An SSO-required user who provides an
/// INVALID recovery code should:
/// 1. Have 2FA validated BEFORE SSO (new order)
/// 2. Get a 2FA error (invalid token)
/// 3. NOT get the recovery-specific SSO message (because recovery didn't complete)
/// 4. NOT be logged in
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_InvalidRecoveryCodeForSsoRequiredUser_FailsAt2FA(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is required
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE (triggers validation order change)
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "INVALID-recovery-code";
// 5. Recovery code is INVALID
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "INVALID-recovery-code")
.Returns(Task.FromResult(false));
// 6. Setup for failed 2FA email (if feature flag enabled)
_featureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail).Returns(true);
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
if (featureFlagEnabled)
{
// NEW BEHAVIOR: 2FA is checked first (due to recovery code request), fails with 2FA error
Assert.Equal(
"Two-step token is invalid. Try again.",
errorResponse.Message);
// Recovery was attempted but failed - flag should NOT be set
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false (recovery failed)");
// Verify failed 2FA email was sent
await _mailService.Received(1).SendFailedTwoFactorAttemptEmailAsync(
user.Email,
TwoFactorProviderType.RecoveryCode,
Arg.Any<DateTime>(),
Arg.Any<string>());
// Verify failed login event was logged
await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_FailedLogIn2fa);
}
else
{
// LEGACY BEHAVIOR: SSO is checked first, blocks before 2FA
Assert.Equal(
"SSO authentication is required.",
errorResponse.Message);
// 2FA validation never happened
await _mailService.DidNotReceive().SendFailedTwoFactorAttemptEmailAsync(
Arg.Any<string>(),
Arg.Any<TwoFactorProviderType>(),
Arg.Any<DateTime>(),
Arg.Any<string>());
}
// In both cases: User is NOT logged in
await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn);
// Verify user failed login count was updated (in new behavior path)
if (featureFlagEnabled)
{
await _userRepository.Received(1).ReplaceAsync(Arg.Is<User>(u =>
u.Id == user.Id && u.FailedLoginCount > 0));
}
}
/// <summary>
/// Tests that non-SSO users can successfully use recovery codes to disable 2FA and log in.
/// This validates:
/// 1. Validation order changes to 2FA-first when recovery code is provided
/// 2. Recovery code validates successfully
/// 3. SSO check passes (user not in SSO-required org)
/// 4. User successfully logs in
/// 5. TwoFactorRecoveryRequested flag is set (for logging/audit purposes)
/// This is the "happy path" for recovery code usage.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RecoveryCodeForNonSsoUser_SuccessfulLogin(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is NOT required (this is a regular user, not in SSO org)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-67890";
// 5. Recovery code is valid
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-67890")
.Returns(Task.FromResult(true));
// 6. Device validation passes
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7. User is not legacy
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 8. Setup user account keys for successful login response
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
"test-private-key",
"test-public-key"
)
});
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.False(context.GrantResult.IsError, "Authentication should succeed for non-SSO user with valid recovery code");
// Verify user successfully logged in
await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_LoggedIn);
// Verify failed login count was reset (successful login)
await _userRepository.Received(1).ReplaceAsync(Arg.Is<User>(u =>
u.Id == user.Id && u.FailedLoginCount == 0));
if (featureFlagEnabled)
{
// NEW BEHAVIOR: Recovery flag should be set for audit purposes
Assert.True(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested flag should be set for audit/logging");
}
else
{
// LEGACY BEHAVIOR: Recovery flag doesn't exist, but login still succeeds
// (SSO check happens before 2FA in legacy, but user is not SSO-required so both pass)
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false in legacy mode");
}
}
private BaseRequestValidationContextFake CreateContext(
ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,

View File

@@ -265,9 +265,10 @@ public class SendEmailOtpRequestValidatorTests
// Arrange
var otpTokenProvider = Substitute.For<IOtpTokenProvider<DefaultOtpTokenProviderOptions>>();
var mailService = Substitute.For<IMailService>();
var featureService = Substitute.For<IFeatureService>();
// Act
var validator = new SendEmailOtpRequestValidator(otpTokenProvider, mailService);
var validator = new SendEmailOtpRequestValidator(featureService, otpTokenProvider, mailService);
// Assert
Assert.NotNull(validator);

View File

@@ -1,6 +1,7 @@
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Dirt.Entities;
using Bit.Core.Dirt.Reports.Models.Data;
using Bit.Core.Dirt.Repositories;
using Bit.Core.Repositories;
using Bit.Core.Test.AutoFixture.Attributes;
@@ -49,6 +50,75 @@ public class OrganizationReportRepositoryTests
Assert.True(records.Count == 4);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task CreateAsync_ShouldPersistAllMetricProperties_WhenSet(
List<EntityFramework.Dirt.Repositories.OrganizationReportRepository> suts,
List<EfRepo.OrganizationRepository> efOrganizationRepos,
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange - Create a report with explicit metric values
var fixture = new Fixture();
var organization = fixture.Create<Organization>();
var report = fixture.Build<OrganizationReport>()
.With(x => x.ApplicationCount, 10)
.With(x => x.ApplicationAtRiskCount, 3)
.With(x => x.CriticalApplicationCount, 5)
.With(x => x.CriticalApplicationAtRiskCount, 2)
.With(x => x.MemberCount, 25)
.With(x => x.MemberAtRiskCount, 7)
.With(x => x.CriticalMemberCount, 12)
.With(x => x.CriticalMemberAtRiskCount, 4)
.With(x => x.PasswordCount, 100)
.With(x => x.PasswordAtRiskCount, 15)
.With(x => x.CriticalPasswordCount, 50)
.With(x => x.CriticalPasswordAtRiskCount, 8)
.Create();
var retrievedReports = new List<OrganizationReport>();
// Act & Assert - Test EF repositories
foreach (var sut in suts)
{
var i = suts.IndexOf(sut);
var efOrganization = await efOrganizationRepos[i].CreateAsync(organization);
sut.ClearChangeTracking();
report.OrganizationId = efOrganization.Id;
var createdReport = await sut.CreateAsync(report);
sut.ClearChangeTracking();
var savedReport = await sut.GetByIdAsync(createdReport.Id);
retrievedReports.Add(savedReport);
}
// Act & Assert - Test SQL repository
var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization);
report.OrganizationId = sqlOrganization.Id;
var sqlCreatedReport = await sqlOrganizationReportRepo.CreateAsync(report);
var savedSqlReport = await sqlOrganizationReportRepo.GetByIdAsync(sqlCreatedReport.Id);
retrievedReports.Add(savedSqlReport);
// Assert - Verify all metric properties are persisted correctly across all repositories
Assert.True(retrievedReports.Count == 4);
foreach (var retrievedReport in retrievedReports)
{
Assert.NotNull(retrievedReport);
Assert.Equal(10, retrievedReport.ApplicationCount);
Assert.Equal(3, retrievedReport.ApplicationAtRiskCount);
Assert.Equal(5, retrievedReport.CriticalApplicationCount);
Assert.Equal(2, retrievedReport.CriticalApplicationAtRiskCount);
Assert.Equal(25, retrievedReport.MemberCount);
Assert.Equal(7, retrievedReport.MemberAtRiskCount);
Assert.Equal(12, retrievedReport.CriticalMemberCount);
Assert.Equal(4, retrievedReport.CriticalMemberAtRiskCount);
Assert.Equal(100, retrievedReport.PasswordCount);
Assert.Equal(15, retrievedReport.PasswordAtRiskCount);
Assert.Equal(50, retrievedReport.CriticalPasswordCount);
Assert.Equal(8, retrievedReport.CriticalPasswordAtRiskCount);
}
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task RetrieveByOrganisation_Works(
OrganizationReportRepository sqlOrganizationReportRepo,
@@ -66,6 +136,67 @@ public class OrganizationReportRepositoryTests
Assert.Equal(secondOrg.Id, secondRetrievedReport.OrganizationId);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task UpdateAsync_ShouldUpdateAllMetricProperties_WhenChanged(
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange - Create initial report with specific metric values
var fixture = new Fixture();
var organization = fixture.Create<Organization>();
var org = await sqlOrganizationRepo.CreateAsync(organization);
var report = fixture.Build<OrganizationReport>()
.With(x => x.OrganizationId, org.Id)
.With(x => x.ApplicationCount, 10)
.With(x => x.ApplicationAtRiskCount, 3)
.With(x => x.CriticalApplicationCount, 5)
.With(x => x.CriticalApplicationAtRiskCount, 2)
.With(x => x.MemberCount, 25)
.With(x => x.MemberAtRiskCount, 7)
.With(x => x.CriticalMemberCount, 12)
.With(x => x.CriticalMemberAtRiskCount, 4)
.With(x => x.PasswordCount, 100)
.With(x => x.PasswordAtRiskCount, 15)
.With(x => x.CriticalPasswordCount, 50)
.With(x => x.CriticalPasswordAtRiskCount, 8)
.Create();
var createdReport = await sqlOrganizationReportRepo.CreateAsync(report);
// Act - Update all metric properties with new values
createdReport.ApplicationCount = 20;
createdReport.ApplicationAtRiskCount = 6;
createdReport.CriticalApplicationCount = 10;
createdReport.CriticalApplicationAtRiskCount = 4;
createdReport.MemberCount = 50;
createdReport.MemberAtRiskCount = 14;
createdReport.CriticalMemberCount = 24;
createdReport.CriticalMemberAtRiskCount = 8;
createdReport.PasswordCount = 200;
createdReport.PasswordAtRiskCount = 30;
createdReport.CriticalPasswordCount = 100;
createdReport.CriticalPasswordAtRiskCount = 16;
await sqlOrganizationReportRepo.UpsertAsync(createdReport);
// Assert - Verify all metric properties were updated correctly
var updatedReport = await sqlOrganizationReportRepo.GetByIdAsync(createdReport.Id);
Assert.NotNull(updatedReport);
Assert.Equal(20, updatedReport.ApplicationCount);
Assert.Equal(6, updatedReport.ApplicationAtRiskCount);
Assert.Equal(10, updatedReport.CriticalApplicationCount);
Assert.Equal(4, updatedReport.CriticalApplicationAtRiskCount);
Assert.Equal(50, updatedReport.MemberCount);
Assert.Equal(14, updatedReport.MemberAtRiskCount);
Assert.Equal(24, updatedReport.CriticalMemberCount);
Assert.Equal(8, updatedReport.CriticalMemberAtRiskCount);
Assert.Equal(200, updatedReport.PasswordCount);
Assert.Equal(30, updatedReport.PasswordAtRiskCount);
Assert.Equal(100, updatedReport.CriticalPasswordCount);
Assert.Equal(16, updatedReport.CriticalPasswordAtRiskCount);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task Delete_Works(
List<EntityFramework.Dirt.Repositories.OrganizationReportRepository> suts,
@@ -359,6 +490,49 @@ public class OrganizationReportRepositoryTests
Assert.Null(result);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task UpdateMetricsAsync_ShouldUpdateMetricsCorrectly(
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange
var (org, report) = await CreateOrganizationAndReportAsync(sqlOrganizationRepo, sqlOrganizationReportRepo);
var metrics = new OrganizationReportMetricsData
{
ApplicationCount = 10,
ApplicationAtRiskCount = 2,
CriticalApplicationCount = 5,
CriticalApplicationAtRiskCount = 1,
MemberCount = 20,
MemberAtRiskCount = 4,
CriticalMemberCount = 10,
CriticalMemberAtRiskCount = 2,
PasswordCount = 100,
PasswordAtRiskCount = 15,
CriticalPasswordCount = 50,
CriticalPasswordAtRiskCount = 5
};
// Act
await sqlOrganizationReportRepo.UpdateMetricsAsync(report.Id, metrics);
var updatedReport = await sqlOrganizationReportRepo.GetByIdAsync(report.Id);
// Assert
Assert.Equal(metrics.ApplicationCount, updatedReport.ApplicationCount);
Assert.Equal(metrics.ApplicationAtRiskCount, updatedReport.ApplicationAtRiskCount);
Assert.Equal(metrics.CriticalApplicationCount, updatedReport.CriticalApplicationCount);
Assert.Equal(metrics.CriticalApplicationAtRiskCount, updatedReport.CriticalApplicationAtRiskCount);
Assert.Equal(metrics.MemberCount, updatedReport.MemberCount);
Assert.Equal(metrics.MemberAtRiskCount, updatedReport.MemberAtRiskCount);
Assert.Equal(metrics.CriticalMemberCount, updatedReport.CriticalMemberCount);
Assert.Equal(metrics.CriticalMemberAtRiskCount, updatedReport.CriticalMemberAtRiskCount);
Assert.Equal(metrics.PasswordCount, updatedReport.PasswordCount);
Assert.Equal(metrics.PasswordAtRiskCount, updatedReport.PasswordAtRiskCount);
Assert.Equal(metrics.CriticalPasswordCount, updatedReport.CriticalPasswordCount);
Assert.Equal(metrics.CriticalPasswordAtRiskCount, updatedReport.CriticalPasswordAtRiskCount);
}
private async Task<(Organization, OrganizationReport)> CreateOrganizationAndReportAsync(
IOrganizationRepository orgRepo,
IOrganizationReportRepository orgReportRepo)

View File

@@ -33,14 +33,69 @@ public static class OrganizationTestHelpers
public static Task<Organization> CreateTestOrganizationAsync(this IOrganizationRepository organizationRepository,
int? seatCount = null,
string identifier = "test")
=> organizationRepository.CreateAsync(new Organization
{
var id = Guid.NewGuid();
return organizationRepository.CreateAsync(new Organization
{
Name = $"{identifier}-{Guid.NewGuid()}",
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Enterprise (Annually)", // TODO: EF does not enforce this being NOT NULl
Name = $"{identifier}-{id}",
BillingEmail = $"billing-{id}@example.com",
Plan = "Enterprise (Annually)",
PlanType = PlanType.EnterpriseAnnually,
Seats = seatCount
Identifier = $"{identifier}-{id}",
BusinessName = $"Test Business {id}",
BusinessAddress1 = "123 Test Street",
BusinessAddress2 = "Suite 100",
BusinessAddress3 = "Building A",
BusinessCountry = "US",
BusinessTaxNumber = "123456789",
Seats = seatCount,
MaxCollections = 50,
UsePolicies = true,
UseSso = true,
UseKeyConnector = true,
UseScim = true,
UseGroups = true,
UseDirectory = true,
UseEvents = true,
UseTotp = true,
Use2fa = true,
UseApi = true,
UseResetPassword = true,
UseSecretsManager = true,
UsePasswordManager = true,
SelfHost = false,
UsersGetPremium = true,
UseCustomPermissions = true,
Storage = 1073741824, // 1 GB in bytes
MaxStorageGb = 10,
Gateway = GatewayType.Stripe,
GatewayCustomerId = $"cus_{id}",
GatewaySubscriptionId = $"sub_{id}",
ReferenceData = "{\"test\":\"data\"}",
Enabled = true,
LicenseKey = $"license-{id}",
PublicKey = "test-public-key",
PrivateKey = "test-private-key",
TwoFactorProviders = null,
ExpirationDate = DateTime.UtcNow.AddYears(1),
MaxAutoscaleSeats = 200,
OwnersNotifiedOfAutoscaling = null,
Status = OrganizationStatusType.Managed,
SmSeats = 50,
SmServiceAccounts = 25,
MaxAutoscaleSmSeats = 100,
MaxAutoscaleSmServiceAccounts = 50,
LimitCollectionCreation = true,
LimitCollectionDeletion = true,
LimitItemDeletion = true,
AllowAdminAccessToAllCollectionItems = true,
UseRiskInsights = true,
UseOrganizationDomains = true,
UseAdminSponsoredFamilies = true,
SyncSeats = false,
UseAutomaticUserConfirmation = true
});
}
/// <summary>
/// Creates a confirmed Owner for the specified organization and user.

View File

@@ -0,0 +1,447 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationUserRepository;
public class GetByUserIdWithPolicyDetailsTests
{
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithConfirmedUser_ReturnsPolicy(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
Email = null
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
Data = CoreHelpers.ClassToJsonData(new { Setting = "value" })
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
var policyDetails = result.Single();
Assert.Equal(orgUser.Id, policyDetails.OrganizationUserId);
Assert.Equal(org.Id, policyDetails.OrganizationId);
Assert.Equal(PolicyType.SingleOrg, policyDetails.PolicyType);
Assert.True(policyDetails.PolicyEnabled);
Assert.Equal(OrganizationUserType.User, policyDetails.OrganizationUserType);
Assert.Equal(OrganizationUserStatusType.Confirmed, policyDetails.OrganizationUserStatus);
Assert.False(policyDetails.IsProvider);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithAcceptedUser_ReturnsPolicy(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Accepted,
Type = OrganizationUserType.Admin,
Email = null
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = false, // Note: disabled policy
Type = PolicyType.RequireSso,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.RequireSso);
// Assert
var policyDetails = result.Single();
Assert.Equal(orgUser.Id, policyDetails.OrganizationUserId);
Assert.False(policyDetails.PolicyEnabled); // Should return even if disabled
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithInvitedUser_ReturnsPolicy(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = null, // invited users have null userId
Status = OrganizationUserStatusType.Invited,
Type = OrganizationUserType.User,
Email = user.Email // invited users have matching Email
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.TwoFactorAuthentication,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.TwoFactorAuthentication);
// Assert
var policyDetails = result.Single();
Assert.Equal(orgUser.Id, policyDetails.OrganizationUserId);
Assert.Equal(OrganizationUserStatusType.Invited, policyDetails.OrganizationUserStatus);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithRevokedUser_ReturnsPolicy(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Revoked,
Type = OrganizationUserType.Owner,
Email = null
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
var policyDetails = result.Single();
Assert.Equal(OrganizationUserStatusType.Revoked, policyDetails.OrganizationUserStatus);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithMultipleOrganizations_ReturnsAllMatchingPolicies(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
// Org1 with SingleOrg policy
var org1 = await organizationRepository.CreateAsync(new Organization
{
Name = "Org 1",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser1 = new OrganizationUser
{
OrganizationId = org1.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
};
await organizationUserRepository.CreateAsync(orgUser1);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org1.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Org2 with SingleOrg policy
var org2 = await organizationRepository.CreateAsync(new Organization
{
Name = "Org 2",
BillingEmail = "billing2@example.com",
Plan = "Test",
});
var orgUser2 = new OrganizationUser
{
OrganizationId = org2.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Admin,
};
await organizationUserRepository.CreateAsync(orgUser2);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org2.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Org3 with RequireSso policy (different type - should not be returned)
var org3 = await organizationRepository.CreateAsync(new Organization
{
Name = "Org 3",
BillingEmail = "billing3@example.com",
Plan = "Test",
});
var orgUser3 = new OrganizationUser
{
OrganizationId = org3.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
};
await organizationUserRepository.CreateAsync(orgUser3);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org3.Id,
Enabled = true,
Type = PolicyType.RequireSso,
});
// Act
var result = (await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg)).ToList();
// Assert - should only get 2 policies (org1 and org2)
Assert.Equal(2, result.Count);
Assert.Contains(result, p => p.OrganizationId == org1.Id && p.OrganizationUserType == OrganizationUserType.User);
Assert.Contains(result, p => p.OrganizationId == org2.Id && p.OrganizationUserType == OrganizationUserType.Admin);
Assert.DoesNotContain(result, p => p.OrganizationId == org3.Id);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithNonExistingPolicyType_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.RequireSso);
// Assert
Assert.Empty(result);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithProviderUser_ReturnsIsProviderTrue(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
var provider = await providerRepository.CreateAsync(new Provider
{
Name = Guid.NewGuid().ToString(),
Enabled = true
});
await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user.Id,
Status = ProviderUserStatusType.Confirmed
});
await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
OrganizationId = org.Id,
ProviderId = provider.Id
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
var policyDetails = result.Single();
Assert.True(policyDetails.IsProvider);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WithCustomUserWithPermissions_ReturnsPermissions(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Custom,
Email = null
};
orgUser.SetPermissions(new Permissions
{
ManagePolicies = true,
EditAnyCollection = true
});
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
var policyDetails = result.Single();
Assert.NotNull(policyDetails.OrganizationUserPermissionsData);
var permissions = CoreHelpers.LoadClassFromJsonData<Permissions>(policyDetails.OrganizationUserPermissionsData);
Assert.True(permissions.ManagePolicies);
Assert.True(permissions.EditAnyCollection);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WhenNoPolicyExists_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
Assert.Empty(result);
}
[Theory, DatabaseData]
public async Task GetByUserIdWithPolicyDetailsAsync_WhenUserNotInOrg_ReturnsEmpty(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = "billing@example.com",
Plan = "Test",
});
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var result = await organizationUserRepository.GetByUserIdWithPolicyDetailsAsync(user.Id, PolicyType.SingleOrg);
// Assert
Assert.Empty(result);
}
}

View File

@@ -461,13 +461,7 @@ public class OrganizationUserRepositoryTests
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
@@ -536,9 +530,72 @@ public class OrganizationUserRepositoryTests
Assert.Equal(organization.SmServiceAccounts, result.SmServiceAccounts);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
}
[Theory, DatabaseData]
public async Task GetManyDetailsByUserAsync_ShouldPopulateSsoPropertiesCorrectly(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ISsoConfigRepository ssoConfigRepository)
{
var user = await userRepository.CreateTestUserAsync();
var organizationWithSso = await organizationRepository.CreateTestOrganizationAsync();
var organizationWithoutSso = await organizationRepository.CreateTestOrganizationAsync();
var orgUserWithSso = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organizationWithSso.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
Email = user.Email
});
var orgUserWithoutSso = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organizationWithoutSso.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
Email = user.Email
});
// Create SSO configuration for first organization only
var serializedSsoConfigData = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize();
var ssoConfig = await ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organizationWithSso.Id,
Enabled = true,
Data = serializedSsoConfigData
});
var results = (await organizationUserRepository.GetManyDetailsByUserAsync(user.Id)).ToList();
Assert.Equal(2, results.Count);
var orgWithSsoDetails = results.Single(r => r.OrganizationId == organizationWithSso.Id);
var orgWithoutSsoDetails = results.Single(r => r.OrganizationId == organizationWithoutSso.Id);
// Organization with SSO should have SSO properties populated
Assert.True(orgWithSsoDetails.SsoEnabled);
Assert.NotNull(orgWithSsoDetails.SsoConfig);
Assert.Equal(serializedSsoConfigData, orgWithSsoDetails.SsoConfig);
// Organization without SSO should have null SSO properties
Assert.Null(orgWithoutSsoDetails.SsoEnabled);
Assert.Null(orgWithoutSsoDetails.SsoConfig);
}
[DatabaseTheory, DatabaseData]

View File

@@ -1,385 +0,0 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.PolicyRepository;
public class GetPolicyDetailsByUserIdTests
{
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_NonInvitedUsers_Works(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
// OrgUser1 - owner of org1 - confirmed
var user = await userRepository.CreateTestUserAsync();
var org1 = await CreateEnterpriseOrg(organizationRepository);
var orgUser1 = new OrganizationUser
{
OrganizationId = org1.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
Email = null // confirmed OrgUsers use the email on the User table
};
await organizationUserRepository.CreateAsync(orgUser1);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org1.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
Data = CoreHelpers.ClassToJsonData(new TestPolicyData { BoolSetting = true, IntSetting = 5 })
});
// OrgUser2 - custom user of org2 - accepted
var org2 = await CreateEnterpriseOrg(organizationRepository);
var orgUser2 = new OrganizationUser
{
OrganizationId = org2.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Accepted,
Type = OrganizationUserType.Custom,
Email = null // accepted OrgUsers use the email on the User table
};
orgUser2.SetPermissions(new Permissions
{
ManagePolicies = true
});
await organizationUserRepository.CreateAsync(orgUser2);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org2.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
Data = CoreHelpers.ClassToJsonData(new TestPolicyData { BoolSetting = false, IntSetting = 15 })
});
// Act
var policyDetails = (await policyRepository.GetPolicyDetailsByUserId(user.Id)).ToList();
// Assert
Assert.Equal(2, policyDetails.Count);
var actualPolicyDetails1 = policyDetails.Find(p => p.OrganizationUserId == orgUser1.Id);
var expectedPolicyDetails1 = new PolicyDetails
{
OrganizationUserId = orgUser1.Id,
OrganizationId = org1.Id,
PolicyType = PolicyType.SingleOrg,
PolicyData = CoreHelpers.ClassToJsonData(new TestPolicyData { BoolSetting = true, IntSetting = 5 }),
OrganizationUserType = OrganizationUserType.Owner,
OrganizationUserStatus = OrganizationUserStatusType.Confirmed,
OrganizationUserPermissionsData = null,
IsProvider = false
};
Assert.Equivalent(expectedPolicyDetails1, actualPolicyDetails1);
Assert.Equivalent(expectedPolicyDetails1.GetDataModel<TestPolicyData>(), new TestPolicyData { BoolSetting = true, IntSetting = 5 });
var actualPolicyDetails2 = policyDetails.Find(p => p.OrganizationUserId == orgUser2.Id);
var expectedPolicyDetails2 = new PolicyDetails
{
OrganizationUserId = orgUser2.Id,
OrganizationId = org2.Id,
PolicyType = PolicyType.SingleOrg,
PolicyData = CoreHelpers.ClassToJsonData(new TestPolicyData { BoolSetting = false, IntSetting = 15 }),
OrganizationUserType = OrganizationUserType.Custom,
OrganizationUserStatus = OrganizationUserStatusType.Accepted,
OrganizationUserPermissionsData = CoreHelpers.ClassToJsonData(new Permissions { ManagePolicies = true }),
IsProvider = false
};
Assert.Equivalent(expectedPolicyDetails2, actualPolicyDetails2);
Assert.Equivalent(expectedPolicyDetails2.GetDataModel<TestPolicyData>(), new TestPolicyData { BoolSetting = false, IntSetting = 15 });
Assert.Equivalent(new Permissions { ManagePolicies = true }, actualPolicyDetails2.GetOrganizationUserCustomPermissions(), strict: true);
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_InvitedUser_Works(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = null, // invited users have null userId
Status = OrganizationUserStatusType.Invited,
Type = OrganizationUserType.Custom,
Email = user.Email // invited users have matching Email
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
var expectedPolicyDetails = new PolicyDetails
{
OrganizationUserId = orgUser.Id,
OrganizationId = org.Id,
PolicyType = PolicyType.SingleOrg,
OrganizationUserType = OrganizationUserType.Custom,
OrganizationUserStatus = OrganizationUserStatusType.Invited,
IsProvider = false
};
Assert.Equivalent(expectedPolicyDetails, actualPolicyDetails.Single());
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_RevokedConfirmedUser_Works(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
// User has been confirmed to the org but then revoked
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Revoked,
Type = OrganizationUserType.Owner,
Email = null
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
var expectedPolicyDetails = new PolicyDetails
{
OrganizationUserId = orgUser.Id,
OrganizationId = org.Id,
PolicyType = PolicyType.SingleOrg,
OrganizationUserType = OrganizationUserType.Owner,
OrganizationUserStatus = OrganizationUserStatusType.Revoked,
IsProvider = false
};
Assert.Equivalent(expectedPolicyDetails, actualPolicyDetails.Single());
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_RevokedInvitedUser_DoesntReturnPolicies(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
// User has been invited to the org but then revoked - without ever being confirmed and linked to a user.
// This is an unhandled edge case because those users will go through policy enforcement later,
// as part of accepting their invite after being restored. For now this is just documented as expected behavior.
var orgUser = new OrganizationUser
{
OrganizationId = org.Id,
UserId = null,
Status = OrganizationUserStatusType.Revoked,
Type = OrganizationUserType.Owner,
Email = user.Email
};
await organizationUserRepository.CreateAsync(orgUser);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
Assert.Empty(actualPolicyDetails);
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_SetsIsProvider(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Arrange provider
var provider = await providerRepository.CreateAsync(new Provider
{
Name = Guid.NewGuid().ToString(),
Enabled = true
});
await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user.Id,
Status = ProviderUserStatusType.Confirmed
});
await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
OrganizationId = org.Id,
ProviderId = provider.Id
});
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
var expectedPolicyDetails = new PolicyDetails
{
OrganizationUserId = orgUser.Id,
OrganizationId = org.Id,
PolicyType = PolicyType.SingleOrg,
OrganizationUserType = OrganizationUserType.Owner,
OrganizationUserStatus = OrganizationUserStatusType.Confirmed,
IsProvider = true
};
Assert.Equivalent(expectedPolicyDetails, actualPolicyDetails.Single());
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_IgnoresDisabledOrganizations(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Org is disabled; its policies remain, but it is now inactive
org.Enabled = false;
await organizationRepository.ReplaceAsync(org);
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
Assert.Empty(actualPolicyDetails);
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_IgnoresDowngradedOrganizations(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = true,
Type = PolicyType.SingleOrg,
});
// Org is downgraded; its policies remain but its plan no longer supports them
org.UsePolicies = false;
org.PlanType = PlanType.TeamsAnnually;
await organizationRepository.ReplaceAsync(org);
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
Assert.Empty(actualPolicyDetails);
}
[Theory, DatabaseData]
public async Task GetPolicyDetailsByUserId_IgnoresDisabledPolicies(
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationRepository organizationRepository,
IPolicyRepository policyRepository)
{
// Arrange
var user = await userRepository.CreateTestUserAsync();
var org = await CreateEnterpriseOrg(organizationRepository);
await organizationUserRepository.CreateTestOrganizationUserAsync(org, user);
await policyRepository.CreateAsync(new Policy
{
OrganizationId = org.Id,
Enabled = false,
Type = PolicyType.SingleOrg,
});
// Act
var actualPolicyDetails = await policyRepository.GetPolicyDetailsByUserId(user.Id);
// Assert
Assert.Empty(actualPolicyDetails);
}
private class TestPolicyData : IPolicyDataModel
{
public bool BoolSetting { get; set; }
public int IntSetting { get; set; }
}
private Task<Organization> CreateEnterpriseOrg(IOrganizationRepository organizationRepository)
=> organizationRepository.CreateAsync(new Organization
{
Name = Guid.NewGuid().ToString(),
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULl
PlanType = PlanType.EnterpriseAnnually,
UsePolicies = true
});
}

View File

@@ -0,0 +1,142 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Repositories;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class ProviderUserRepositoryTests
{
[Theory, DatabaseData]
public async Task GetManyOrganizationDetailsByUserAsync_ShouldPopulatePropertiesCorrectly(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository,
ISsoConfigRepository ssoConfigRepository)
{
var user = await userRepository.CreateTestUserAsync();
var organizationWithSso = await organizationRepository.CreateTestOrganizationAsync();
var organizationWithoutSso = await organizationRepository.CreateTestOrganizationAsync();
var provider = await providerRepository.CreateAsync(new Provider
{
Name = "Test Provider",
Enabled = true,
Type = ProviderType.Msp
});
var providerUser = await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user.Id,
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin
});
var providerOrganizationWithSso = await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organizationWithSso.Id
});
var providerOrganizationWithoutSso = await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organizationWithoutSso.Id
});
// Create SSO configuration for first organization only
var serializedSsoConfigData = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize();
var ssoConfig = await ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organizationWithSso.Id,
Enabled = true,
Data = serializedSsoConfigData
});
var results = (await providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed)).ToList();
Assert.Equal(2, results.Count);
var orgWithSsoDetails = results.Single(r => r.OrganizationId == organizationWithSso.Id);
var orgWithoutSsoDetails = results.Single(r => r.OrganizationId == organizationWithoutSso.Id);
// Verify all properties for both organizations
AssertProviderOrganizationDetails(orgWithSsoDetails, organizationWithSso, user, provider, providerUser);
AssertProviderOrganizationDetails(orgWithoutSsoDetails, organizationWithoutSso, user, provider, providerUser);
// Organization without SSO should have null SSO properties
Assert.Null(orgWithoutSsoDetails.SsoEnabled);
Assert.Null(orgWithoutSsoDetails.SsoConfig);
// Organization with SSO should have SSO properties populated
Assert.True(orgWithSsoDetails.SsoEnabled);
Assert.NotNull(orgWithSsoDetails.SsoConfig);
Assert.Equal(serializedSsoConfigData, orgWithSsoDetails.SsoConfig);
}
private static void AssertProviderOrganizationDetails(
ProviderUserOrganizationDetails actual,
Organization expectedOrganization,
User expectedUser,
Provider expectedProvider,
ProviderUser expectedProviderUser)
{
// Organization properties
Assert.Equal(expectedOrganization.Id, actual.OrganizationId);
Assert.Equal(expectedUser.Id, actual.UserId);
Assert.Equal(expectedOrganization.Name, actual.Name);
Assert.Equal(expectedOrganization.UsePolicies, actual.UsePolicies);
Assert.Equal(expectedOrganization.UseSso, actual.UseSso);
Assert.Equal(expectedOrganization.UseKeyConnector, actual.UseKeyConnector);
Assert.Equal(expectedOrganization.UseScim, actual.UseScim);
Assert.Equal(expectedOrganization.UseGroups, actual.UseGroups);
Assert.Equal(expectedOrganization.UseDirectory, actual.UseDirectory);
Assert.Equal(expectedOrganization.UseEvents, actual.UseEvents);
Assert.Equal(expectedOrganization.UseTotp, actual.UseTotp);
Assert.Equal(expectedOrganization.Use2fa, actual.Use2fa);
Assert.Equal(expectedOrganization.UseApi, actual.UseApi);
Assert.Equal(expectedOrganization.UseResetPassword, actual.UseResetPassword);
Assert.Equal(expectedOrganization.UsersGetPremium, actual.UsersGetPremium);
Assert.Equal(expectedOrganization.UseCustomPermissions, actual.UseCustomPermissions);
Assert.Equal(expectedOrganization.SelfHost, actual.SelfHost);
Assert.Equal(expectedOrganization.Seats, actual.Seats);
Assert.Equal(expectedOrganization.MaxCollections, actual.MaxCollections);
Assert.Equal(expectedOrganization.MaxStorageGb, actual.MaxStorageGb);
Assert.Equal(expectedOrganization.Identifier, actual.Identifier);
Assert.Equal(expectedOrganization.PublicKey, actual.PublicKey);
Assert.Equal(expectedOrganization.PrivateKey, actual.PrivateKey);
Assert.Equal(expectedOrganization.Enabled, actual.Enabled);
Assert.Equal(expectedOrganization.PlanType, actual.PlanType);
Assert.Equal(expectedOrganization.LimitCollectionCreation, actual.LimitCollectionCreation);
Assert.Equal(expectedOrganization.LimitCollectionDeletion, actual.LimitCollectionDeletion);
Assert.Equal(expectedOrganization.LimitItemDeletion, actual.LimitItemDeletion);
Assert.Equal(expectedOrganization.AllowAdminAccessToAllCollectionItems, actual.AllowAdminAccessToAllCollectionItems);
Assert.Equal(expectedOrganization.UseRiskInsights, actual.UseRiskInsights);
Assert.Equal(expectedOrganization.UseOrganizationDomains, actual.UseOrganizationDomains);
Assert.Equal(expectedOrganization.UseAdminSponsoredFamilies, actual.UseAdminSponsoredFamilies);
Assert.Equal(expectedOrganization.UseAutomaticUserConfirmation, actual.UseAutomaticUserConfirmation);
// Provider-specific properties
Assert.Equal(expectedProvider.Id, actual.ProviderId);
Assert.Equal(expectedProvider.Name, actual.ProviderName);
Assert.Equal(expectedProvider.Type, actual.ProviderType);
Assert.Equal(expectedProviderUser.Id, actual.ProviderUserId);
Assert.Equal(expectedProviderUser.Status, actual.Status);
Assert.Equal(expectedProviderUser.Type, actual.Type);
}
}

View File

@@ -1,6 +1,7 @@
using AspNetCoreRateLimit;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Push;
using Bit.Core.Platform.PushRegistration.Internal;
using Bit.Core.Repositories;

View File

@@ -0,0 +1,250 @@
#nullable enable
using System.Text.Json;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Core.Utilities;
using Bit.Notifications;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.SignalR;
using NSubstitute;
namespace Notifications.Test;
[SutProviderCustomize]
[NotificationCustomize(false)]
public class HubHelpersTest
{
[Theory]
[BitAutoData]
public async Task SendNotificationToHubAsync_NotificationPushNotificationGlobal_NothingSent(
SutProvider<HubHelpers> sutProvider,
NotificationPushNotification notification,
string contextId, CancellationToken cancellationToke)
{
notification.Global = true;
notification.InstallationId = null;
notification.UserId = null;
notification.OrganizationId = null;
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToke);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).Group(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData]
public async Task
SendNotificationToHubAsync_NotificationPushNotificationInstallationIdProvidedClientTypeAll_SentToGroupInstallation(
SutProvider<HubHelpers> sutProvider,
NotificationPushNotification notification,
string contextId, CancellationToken cancellationToken)
{
notification.UserId = null;
notification.OrganizationId = null;
notification.ClientType = ClientType.All;
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"Installation_{notification.InstallationId!.Value.ToString()}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async Task
SendNotificationToHubAsync_NotificationPushNotificationInstallationIdProvidedClientTypeNotAll_SentToGroupInstallationClientType(
ClientType clientType, SutProvider<HubHelpers> sutProvider,
NotificationPushNotification notification,
string contextId, CancellationToken cancellationToken)
{
notification.UserId = null;
notification.OrganizationId = null;
notification.ClientType = clientType;
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"Installation_ClientType_{notification.InstallationId!.Value}_{clientType}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData(false)]
[BitAutoData(true)]
public async Task SendNotificationToHubAsync_NotificationPushNotificationUserIdProvidedClientTypeAll_SentToUser(
bool organizationIdProvided, SutProvider<HubHelpers> sutProvider,
NotificationPushNotification notification,
string contextId, CancellationToken cancellationToken)
{
notification.InstallationId = null;
notification.ClientType = ClientType.All;
if (!organizationIdProvided)
{
notification.OrganizationId = null;
}
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.User(notification.UserId!.Value.ToString())
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).Group(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData(false, ClientType.Browser)]
[BitAutoData(false, ClientType.Desktop)]
[BitAutoData(false, ClientType.Mobile)]
[BitAutoData(false, ClientType.Web)]
[BitAutoData(true, ClientType.Browser)]
[BitAutoData(true, ClientType.Desktop)]
[BitAutoData(true, ClientType.Mobile)]
[BitAutoData(true, ClientType.Web)]
public async Task
SendNotificationToHubAsync_NotificationPushNotificationUserIdProvidedClientTypeNotAll_SentToGroupUserClientType(
bool organizationIdProvided, ClientType clientType, SutProvider<HubHelpers> sutProvider,
NotificationPushNotification notification,
string contextId, CancellationToken cancellationToken)
{
notification.InstallationId = null;
notification.ClientType = clientType;
if (!organizationIdProvided)
{
notification.OrganizationId = null;
}
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"UserClientType_{notification.UserId!.Value}_{clientType}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData]
public async Task
SendNotificationToHubAsync_NotificationPushNotificationOrganizationIdProvidedClientTypeAll_SentToGroupOrganization(
SutProvider<HubHelpers> sutProvider, string contextId,
NotificationPushNotification notification,
CancellationToken cancellationToken)
{
notification.UserId = null;
notification.InstallationId = null;
notification.ClientType = ClientType.All;
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"Organization_{notification.OrganizationId!.Value}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
[Theory]
[BitAutoData(ClientType.Browser)]
[BitAutoData(ClientType.Desktop)]
[BitAutoData(ClientType.Mobile)]
[BitAutoData(ClientType.Web)]
public async Task
SendNotificationToHubAsync_NotificationPushNotificationOrganizationIdProvidedClientTypeNotAll_SentToGroupOrganizationClientType(
ClientType clientType, SutProvider<HubHelpers> sutProvider, string contextId,
NotificationPushNotification notification,
CancellationToken cancellationToken)
{
notification.UserId = null;
notification.InstallationId = null;
notification.ClientType = clientType;
var json = ToNotificationJson(notification, PushType.Notification, contextId);
await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken);
sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
await sutProvider.GetDependency<IHubContext<NotificationsHub>>().Clients.Received(1)
.Group($"OrganizationClientType_{notification.OrganizationId!.Value}_{clientType}")
.Received(1)
.SendCoreAsync("ReceiveMessage", Arg.Is<object?[]>(objects =>
objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0],
PushType.Notification, contextId)),
cancellationToken);
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0).User(Arg.Any<string>());
sutProvider.GetDependency<IHubContext<AnonymousNotificationsHub>>().Clients.Received(0)
.Group(Arg.Any<string>());
}
private static string ToNotificationJson(object payload, PushType type, string contextId)
{
var notification = new PushNotificationData<object>(type, payload, contextId);
return JsonSerializer.Serialize(notification, JsonHelpers.IgnoreWritingNull);
}
private static bool IsNotificationPushNotificationEqual(NotificationPushNotification expected, object? actual,
PushType type, string contextId)
{
if (actual is not PushNotificationData<NotificationPushNotification> pushNotificationData)
{
return false;
}
return pushNotificationData.Type == type &&
pushNotificationData.ContextId == contextId &&
expected.Id == pushNotificationData.Payload.Id &&
expected.UserId == pushNotificationData.Payload.UserId &&
expected.OrganizationId == pushNotificationData.Payload.OrganizationId &&
expected.ClientType == pushNotificationData.Payload.ClientType &&
expected.RevisionDate == pushNotificationData.Payload.RevisionDate;
}
}

View File

@@ -18,5 +18,7 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Notifications\Notifications.csproj" />
<ProjectReference Include="..\Common\Common.csproj" />
<ProjectReference Include="..\Core.Test\Core.Test.csproj" />
</ItemGroup>
</Project>