1
0
mirror of https://github.com/bitwarden/server synced 2026-01-01 16:13:33 +00:00

Merge branch 'main' into vault/pm-25957/sharing-cipher-to-org

This commit is contained in:
Nick Krantz
2025-11-17 08:37:39 -06:00
committed by GitHub
303 changed files with 32048 additions and 2257 deletions

View File

@@ -0,0 +1,197 @@
using System.Net;
using Bit.Api.AdminConsole.Authorization;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.Models.Request.Organizations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Api;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Xunit;
namespace Bit.Api.IntegrationTest.AdminConsole.Controllers;
public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture<ApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly ApiApplicationFactory _factory;
private readonly LoginHelper _loginHelper;
private Organization _organization = null!;
private string _ownerEmail = null!;
public OrganizationUsersControllerPutResetPasswordTests(ApiApplicationFactory apiFactory)
{
_factory = apiFactory;
_factory.SubstituteService<IFeatureService>(featureService =>
{
featureService
.IsEnabled(FeatureFlagKeys.AccountRecoveryCommand)
.Returns(true);
});
_client = _factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}
public async Task InitializeAsync()
{
_ownerEmail = $"reset-password-test-{Guid.NewGuid()}@example.com";
await _factory.LoginWithNewAccount(_ownerEmail);
(_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023,
ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card);
// Enable reset password and policies for the organization
var organizationRepository = _factory.GetService<IOrganizationRepository>();
_organization.UseResetPassword = true;
_organization.UsePolicies = true;
await organizationRepository.ReplaceAsync(_organization);
// Enable the ResetPassword policy
var policyRepository = _factory.GetService<IPolicyRepository>();
await policyRepository.CreateAsync(new Policy
{
OrganizationId = _organization.Id,
Type = PolicyType.ResetPassword,
Enabled = true,
Data = "{}"
});
}
public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}
/// <summary>
/// Helper method to set the ResetPasswordKey on an organization user, which is required for account recovery
/// </summary>
private async Task SetResetPasswordKeyAsync(OrganizationUser orgUser)
{
var organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
orgUser.ResetPasswordKey = "encrypted-reset-password-key";
await organizationUserRepository.ReplaceAsync(orgUser);
}
[Fact]
public async Task PutResetPassword_AsHigherRole_CanRecoverLowerRole()
{
// Arrange
var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Owner);
await _loginHelper.LoginAsync(ownerEmail);
var (_, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.User);
await SetResetPasswordKeyAsync(targetOrgUser);
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
[Fact]
public async Task PutResetPassword_AsLowerRole_CannotRecoverHigherRole()
{
// Arrange
var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Admin);
await _loginHelper.LoginAsync(adminEmail);
var (_, targetOwnerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.Owner);
await SetResetPasswordKeyAsync(targetOwnerOrgUser);
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOwnerOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var model = await response.Content.ReadFromJsonAsync<ErrorResponseModel>();
Assert.Contains(RecoverAccountAuthorizationHandler.FailureReason, model.Message);
}
[Fact]
public async Task PutResetPassword_CannotRecoverProviderAccount()
{
// Arrange - Create owner who will try to recover the provider account
var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory,
_organization.Id, OrganizationUserType.Owner);
await _loginHelper.LoginAsync(ownerEmail);
// Create a user who is also a provider user
var (targetUserEmail, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(
_factory, _organization.Id, OrganizationUserType.User);
await SetResetPasswordKeyAsync(targetOrgUser);
// Add the target user as a provider user to a different provider
var providerRepository = _factory.GetService<IProviderRepository>();
var providerUserRepository = _factory.GetService<IProviderUserRepository>();
var userRepository = _factory.GetService<IUserRepository>();
var provider = await providerRepository.CreateAsync(new Provider
{
Name = "Test Provider",
BusinessName = "Test Provider Business",
BillingEmail = "provider@example.com",
Type = ProviderType.Msp,
Status = ProviderStatusType.Created,
Enabled = true
});
var targetUser = await userRepository.GetByEmailAsync(targetUserEmail);
Assert.NotNull(targetUser);
await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = targetUser.Id,
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin
});
var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel
{
NewMasterPasswordHash = "new-master-password-hash",
Key = "encrypted-recovery-key"
};
// Act
var response = await _client.PutAsJsonAsync(
$"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password",
resetPasswordRequest);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var model = await response.Content.ReadFromJsonAsync<ErrorResponseModel>();
Assert.Equal(RecoverAccountAuthorizationHandler.ProviderFailureReason, model.Message);
}
}

View File

@@ -67,7 +67,6 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
},
Metadata = new Dictionary<string, object>
@@ -148,7 +147,6 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
{
Policy = new PolicyRequestModel
{
Type = policyType,
Enabled = true,
Data = new Dictionary<string, object>
{
@@ -211,4 +209,192 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
}
}
[Fact]
public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minLength", "not a number" }, // Wrong type - should be int
{ "requireUpper", true }
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Contains("minLength", content); // Verify field name is in error message
}
[Fact]
public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task PutVNext_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minComplexity", "not a number" }, // Wrong type - should be int
{ "minLength", 12 }
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Contains("minComplexity", content); // Verify field name is in error message
}
[Fact]
public async Task PutVNext_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task PutVNext_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
}
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.SingleOrg;
var request = new PolicyRequestModel
{
Enabled = true,
Data = null
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
[Fact]
public async Task PutVNext_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.TwoFactorAuthentication;
var request = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Enabled = true,
Data = null
},
Metadata = null
};
// Act
var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext",
JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
}

View File

@@ -64,6 +64,17 @@ public class MembersControllerTests : IClassFixture<ApiApplicationFactory>, IAsy
var (userEmail4, orgUser4) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id,
OrganizationUserType.Admin);
var collection1 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 1", users:
[
new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = false, Manage = true },
new CollectionAccessSelection { Id = orgUser3.Id, ReadOnly = true, HidePasswords = false, Manage = false }
]);
var collection2 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 2", users:
[
new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = true, Manage = false }
]);
var response = await _client.GetAsync($"/public/members");
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var result = await response.Content.ReadFromJsonAsync<ListResponseModel<MemberResponseModel>>();
@@ -71,23 +82,47 @@ public class MembersControllerTests : IClassFixture<ApiApplicationFactory>, IAsy
Assert.Equal(5, result.Data.Count());
// The owner
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner));
var ownerResult = result.Data.SingleOrDefault(m => m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner);
Assert.NotNull(ownerResult);
Assert.Empty(ownerResult.Collections);
// The custom user
// The custom user with collections
var user1Result = result.Data.Single(m => m.Email == userEmail1);
Assert.Equal(OrganizationUserType.Custom, user1Result.Type);
AssertHelper.AssertPropertyEqual(
new PermissionsModel { AccessImportExport = true, ManagePolicies = true, AccessReports = true },
user1Result.Permissions);
// Verify collections
Assert.NotNull(user1Result.Collections);
Assert.Equal(2, user1Result.Collections.Count());
var user1Collection1 = user1Result.Collections.Single(c => c.Id == collection1.Id);
Assert.False(user1Collection1.ReadOnly);
Assert.False(user1Collection1.HidePasswords);
Assert.True(user1Collection1.Manage);
var user1Collection2 = user1Result.Collections.Single(c => c.Id == collection2.Id);
Assert.False(user1Collection2.ReadOnly);
Assert.True(user1Collection2.HidePasswords);
Assert.False(user1Collection2.Manage);
// Everyone else
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail2 && m.Type == OrganizationUserType.Owner));
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail3 && m.Type == OrganizationUserType.User));
Assert.NotNull(result.Data.SingleOrDefault(m =>
m.Email == userEmail4 && m.Type == OrganizationUserType.Admin));
// The other owner
var user2Result = result.Data.SingleOrDefault(m => m.Email == userEmail2 && m.Type == OrganizationUserType.Owner);
Assert.NotNull(user2Result);
Assert.Empty(user2Result.Collections);
// The user with one collection
var user3Result = result.Data.SingleOrDefault(m => m.Email == userEmail3 && m.Type == OrganizationUserType.User);
Assert.NotNull(user3Result);
Assert.NotNull(user3Result.Collections);
Assert.Single(user3Result.Collections);
var user3Collection1 = user3Result.Collections.Single(c => c.Id == collection1.Id);
Assert.True(user3Collection1.ReadOnly);
Assert.False(user3Collection1.HidePasswords);
Assert.False(user3Collection1.Manage);
// The admin with no collections
var user4Result = result.Data.SingleOrDefault(m => m.Email == userEmail4 && m.Type == OrganizationUserType.Admin);
Assert.NotNull(user4Result);
Assert.Empty(user4Result.Collections);
}
[Fact]

View File

@@ -160,4 +160,86 @@ public class PoliciesControllerTests : IClassFixture<ApiApplicationFactory>, IAs
Assert.Equal(15, data.MinLength);
Assert.Equal(true, data.RequireUpper);
}
[Fact]
public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.MasterPassword;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "minLength", "not a number" }, // Wrong type - should be int
{ "requireUpper", true }
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.SendOptions;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "disableHideEmail", "not a boolean" } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest()
{
// Arrange
var policyType = PolicyType.ResetPassword;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = new Dictionary<string, object>
{
{ "autoEnrollEnabled", 123 } // Wrong type - should be bool
}
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Put_PolicyWithNullData_Success()
{
// Arrange
var policyType = PolicyType.DisableSend;
var request = new PolicyUpdateRequestModel
{
Enabled = true,
Data = null
};
// Act
var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request));
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
}
}

View File

@@ -151,6 +151,28 @@ public static class OrganizationTestHelpers
return group;
}
/// <summary>
/// Creates a collection with optional user and group associations.
/// </summary>
public static async Task<Collection> CreateCollectionAsync(
ApiApplicationFactory factory,
Guid organizationId,
string name,
IEnumerable<CollectionAccessSelection>? users = null,
IEnumerable<CollectionAccessSelection>? groups = null)
{
var collectionRepository = factory.GetService<ICollectionRepository>();
var collection = new Collection
{
OrganizationId = organizationId,
Name = name,
Type = CollectionType.SharedCollection
};
await collectionRepository.CreateAsync(collection, groups, users);
return collection;
}
/// <summary>
/// Enables the Organization Data Ownership policy for the specified organization.
/// </summary>

View File

@@ -0,0 +1,296 @@
using System.Security.Claims;
using Bit.Api.AdminConsole.Authorization;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Test.AutoFixture.OrganizationUserFixtures;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Authorization;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Authorization;
[SutProviderCustomize]
public class RecoverAccountAuthorizationHandlerTests
{
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsProvider_TargetUserNotProvider_Authorized(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null);
MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsNotMemberOrProvider_NotAuthorized(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason);
}
// Pairing of CurrentContextOrganization (current user permissions) and target user role
// Read this as: a ___ can recover the account for a ___
public static IEnumerable<object[]> AuthorizedRoleCombinations => new object[][]
{
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.User],
};
[Theory, BitMemberAutoData(nameof(AuthorizedRoleCombinations))]
public async Task AuthorizeMemberAsync_RecoverEqualOrLesserRoles_TargetUserNotProvider_Authorized(
CurrentContextOrganization currentContextOrganization,
OrganizationUserType targetOrganizationUserType,
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.Type = targetOrganizationUserType;
currentContextOrganization.Id = targetOrganizationUser.OrganizationId;
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
// Pairing of CurrentContextOrganization (current user permissions) and target user role
// Read this as: a ___ cannot recover the account for a ___
public static IEnumerable<object[]> UnauthorizedRoleCombinations => new object[][]
{
// These roles should fail because you cannot recover a greater role
[new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true} }, OrganizationUserType.Admin],
// These roles are never authorized to recover any account
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.User],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Owner],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Admin],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Custom],
[new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.User],
};
[Theory, BitMemberAutoData(nameof(UnauthorizedRoleCombinations))]
public async Task AuthorizeMemberAsync_InvalidRoles_TargetUserNotProvider_Unauthorized(
CurrentContextOrganization currentContextOrganization,
OrganizationUserType targetOrganizationUserType,
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.Type = targetOrganizationUserType;
currentContextOrganization.Id = targetOrganizationUser.OrganizationId;
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_TargetUserIdIsNull_DoesNotBlock(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal)
{
// Arrange
targetOrganizationUser.UserId = null;
MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser);
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
// This should shortcut the provider escalation check
await sutProvider.GetDependency<IProviderUserRepository>().DidNotReceiveWithAnyArgs()
.GetManyByUserAsync(Arg.Any<Guid>());
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserIsMemberOfAllTargetUserProviders_DoesNotBlock(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal,
Guid providerId1,
Guid providerId2)
{
// Arrange
var targetUserProviders = new List<ProviderUser>
{
new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId },
new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId }
};
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByUserAsync(targetOrganizationUser.UserId!.Value)
.Returns(targetUserProviders);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId1)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId2)
.Returns(true);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
Assert.True(context.HasSucceeded);
}
[Theory, BitAutoData]
public async Task HandleRequirementAsync_CurrentUserMissingProviderMembership_Blocks(
SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
[OrganizationUser] OrganizationUser targetOrganizationUser,
ClaimsPrincipal claimsPrincipal,
Guid providerId1,
Guid providerId2)
{
// Arrange
var targetUserProviders = new List<ProviderUser>
{
new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId },
new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId }
};
var context = new AuthorizationHandlerContext(
[new RecoverAccountAuthorizationRequirement()],
claimsPrincipal,
targetOrganizationUser);
MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByUserAsync(targetOrganizationUser.UserId!.Value)
.Returns(targetUserProviders);
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId1)
.Returns(true);
// Not a member of this provider
sutProvider.GetDependency<ICurrentContext>()
.ProviderUser(providerId2)
.Returns(false);
// Act
await sutProvider.Sut.HandleAsync(context);
// Assert
AssertFailed(context, RecoverAccountAuthorizationHandler.ProviderFailureReason);
}
private static void MockOrganizationClaims(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser,
CurrentContextOrganization? currentContextOrganization)
{
sutProvider.GetDependency<IOrganizationContext>()
.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId)
.Returns(currentContextOrganization);
}
private static void MockCurrentUserIsProvider(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser)
{
sutProvider.GetDependency<IOrganizationContext>()
.IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId)
.Returns(true);
}
private static void MockCurrentUserIsOwner(SutProvider<RecoverAccountAuthorizationHandler> sutProvider,
ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser)
{
var currentContextOrganization = new CurrentContextOrganization
{
Id = targetOrganizationUser.OrganizationId,
Type = OrganizationUserType.Owner
};
sutProvider.GetDependency<IOrganizationContext>()
.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId)
.Returns(currentContextOrganization);
}
private static void AssertFailed(AuthorizationHandlerContext context, string expectedMessage)
{
Assert.True(context.HasFailed);
var failureReason = Assert.Single(context.FailureReasons);
Assert.Equal(expectedMessage, failureReason.Message);
}
}

View File

@@ -133,6 +133,29 @@ public class OrganizationIntegrationControllerTests
.DeleteAsync(organizationIntegration);
}
[Theory, BitAutoData]
public async Task PostDeleteAsync_AllParamsProvided_Succeeds(
SutProvider<OrganizationIntegrationController> sutProvider,
Guid organizationId,
OrganizationIntegration organizationIntegration)
{
organizationIntegration.OrganizationId = organizationId;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegration);
await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id);
await sutProvider.GetDependency<IOrganizationIntegrationRepository>().Received(1)
.GetByIdAsync(organizationIntegration.Id);
await sutProvider.GetDependency<IOrganizationIntegrationRepository>().Received(1)
.DeleteAsync(organizationIntegration);
}
[Theory, BitAutoData]
public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound(
SutProvider<OrganizationIntegrationController> sutProvider,

View File

@@ -51,6 +51,36 @@ public class OrganizationIntegrationsConfigurationControllerTests
.DeleteAsync(organizationIntegrationConfiguration);
}
[Theory, BitAutoData]
public async Task PostDeleteAsync_AllParamsProvided_Succeeds(
SutProvider<OrganizationIntegrationConfigurationController> sutProvider,
Guid organizationId,
OrganizationIntegration organizationIntegration,
OrganizationIntegrationConfiguration organizationIntegrationConfiguration)
{
organizationIntegration.OrganizationId = organizationId;
organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegration);
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegrationConfiguration);
await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id);
await sutProvider.GetDependency<IOrganizationIntegrationRepository>().Received(1)
.GetByIdAsync(organizationIntegration.Id);
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.GetByIdAsync(organizationIntegrationConfiguration.Id);
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.DeleteAsync(organizationIntegrationConfiguration);
}
[Theory, BitAutoData]
public async Task DeleteAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound(
SutProvider<OrganizationIntegrationConfigurationController> sutProvider,
@@ -199,27 +229,6 @@ public class OrganizationIntegrationsConfigurationControllerTests
.GetManyByIntegrationAsync(organizationIntegration.Id);
}
// [Theory, BitAutoData]
// public async Task GetAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound(
// SutProvider<OrganizationIntegrationConfigurationController> sutProvider,
// Guid organizationId,
// OrganizationIntegration organizationIntegration)
// {
// organizationIntegration.OrganizationId = organizationId;
// sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
// sutProvider.GetDependency<ICurrentContext>()
// .OrganizationOwner(organizationId)
// .Returns(true);
// sutProvider.GetDependency<IOrganizationIntegrationRepository>()
// .GetByIdAsync(Arg.Any<Guid>())
// .Returns(organizationIntegration);
// sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
// .GetByIdAsync(Arg.Any<Guid>())
// .ReturnsNull();
//
// await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.Empty, Guid.Empty));
// }
//
[Theory, BitAutoData]
public async Task GetAsync_IntegrationDoesNotExist_ThrowsNotFound(
SutProvider<OrganizationIntegrationConfigurationController> sutProvider,
@@ -293,15 +302,16 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(createResponse);
Assert.Equal(expected.Id, createResponse.Id);
Assert.Equal(expected.Configuration, createResponse.Configuration);
Assert.Equal(expected.EventType, createResponse.EventType);
Assert.Equal(expected.Filters, createResponse.Filters);
Assert.Equal(expected.Template, createResponse.Template);
}
[Theory, BitAutoData]
@@ -331,15 +341,16 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(createResponse);
Assert.Equal(expected.Id, createResponse.Id);
Assert.Equal(expected.Configuration, createResponse.Configuration);
Assert.Equal(expected.EventType, createResponse.EventType);
Assert.Equal(expected.Filters, createResponse.Filters);
Assert.Equal(expected.Template, createResponse.Template);
}
[Theory, BitAutoData]
@@ -369,15 +380,16 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model);
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.CreateAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(createResponse);
Assert.Equal(expected.Id, createResponse.Id);
Assert.Equal(expected.Configuration, createResponse.Configuration);
Assert.Equal(expected.EventType, createResponse.EventType);
Assert.Equal(expected.Filters, createResponse.Filters);
Assert.Equal(expected.Template, createResponse.Template);
}
[Theory, BitAutoData]
@@ -575,7 +587,7 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.UpdateAsync(
var updateResponse = await sutProvider.Sut.UpdateAsync(
organizationId,
organizationIntegration.Id,
organizationIntegrationConfiguration.Id,
@@ -583,11 +595,12 @@ public class OrganizationIntegrationsConfigurationControllerTests
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.ReplaceAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(updateResponse);
Assert.Equal(expected.Id, updateResponse.Id);
Assert.Equal(expected.Configuration, updateResponse.Configuration);
Assert.Equal(expected.EventType, updateResponse.EventType);
Assert.Equal(expected.Filters, updateResponse.Filters);
Assert.Equal(expected.Template, updateResponse.Template);
}
@@ -619,7 +632,7 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.UpdateAsync(
var updateResponse = await sutProvider.Sut.UpdateAsync(
organizationId,
organizationIntegration.Id,
organizationIntegrationConfiguration.Id,
@@ -627,11 +640,12 @@ public class OrganizationIntegrationsConfigurationControllerTests
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.ReplaceAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(updateResponse);
Assert.Equal(expected.Id, updateResponse.Id);
Assert.Equal(expected.Configuration, updateResponse.Configuration);
Assert.Equal(expected.EventType, updateResponse.EventType);
Assert.Equal(expected.Filters, updateResponse.Filters);
Assert.Equal(expected.Template, updateResponse.Template);
}
[Theory, BitAutoData]
@@ -662,7 +676,7 @@ public class OrganizationIntegrationsConfigurationControllerTests
sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns(organizationIntegrationConfiguration);
var requestAction = await sutProvider.Sut.UpdateAsync(
var updateResponse = await sutProvider.Sut.UpdateAsync(
organizationId,
organizationIntegration.Id,
organizationIntegrationConfiguration.Id,
@@ -670,11 +684,12 @@ public class OrganizationIntegrationsConfigurationControllerTests
await sutProvider.GetDependency<IOrganizationIntegrationConfigurationRepository>().Received(1)
.ReplaceAsync(Arg.Any<OrganizationIntegrationConfiguration>());
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(requestAction);
Assert.Equal(expected.Id, requestAction.Id);
Assert.Equal(expected.Configuration, requestAction.Configuration);
Assert.Equal(expected.EventType, requestAction.EventType);
Assert.Equal(expected.Template, requestAction.Template);
Assert.IsType<OrganizationIntegrationConfigurationResponseModel>(updateResponse);
Assert.Equal(expected.Id, updateResponse.Id);
Assert.Equal(expected.Configuration, updateResponse.Configuration);
Assert.Equal(expected.EventType, updateResponse.EventType);
Assert.Equal(expected.Filters, updateResponse.Filters);
Assert.Equal(expected.Template, updateResponse.Template);
}
[Theory, BitAutoData]

View File

@@ -1,11 +1,14 @@
using System.Security.Claims;
using Bit.Api.AdminConsole.Authorization;
using Bit.Api.AdminConsole.Controllers;
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Models.Request.Organizations;
using Bit.Api.Vault.AuthorizationHandlers.Collections;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
@@ -16,6 +19,7 @@ using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Api;
using Bit.Core.Models.Business;
using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations;
@@ -30,6 +34,7 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc.ModelBinding;
using NSubstitute;
using Xunit;
@@ -440,4 +445,153 @@ public class OrganizationUsersControllerTests
Assert.Equal("Master Password reset is required, but not provided.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_CallsLegacyPath(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(true);
sutProvider.GetDependency<IUserService>().AdminResetPasswordAsync(Arg.Any<OrganizationUserType>(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<IUserService>().Received(1)
.AdminResetPasswordAsync(OrganizationUserType.Owner, orgId, orgUserId, model.NewMasterPasswordHash, model.Key);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_WhenOrgUserTypeIsNull_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(false);
sutProvider.GetDependency<ICurrentContext>().Organizations.Returns(new List<CurrentContextOrganization>());
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagDisabled_WhenAdminResetPasswordFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false);
sutProvider.GetDependency<ICurrentContext>().OrganizationOwner(orgId).Returns(true);
sutProvider.GetDependency<IUserService>().AdminResetPasswordAsync(Arg.Any<OrganizationUserType>(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error 1" }));
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ModelStateDictionary>>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationUserNotFound_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model,
SutProvider<OrganizationUsersController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns((OrganizationUser)null);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationIdMismatch_ReturnsNotFound(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = Guid.NewGuid();
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<NotFound>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenAuthorizationFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Failed());
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ErrorResponseModel>>(result);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountSucceeds_ReturnsOk(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Success());
sutProvider.GetDependency<IAdminRecoverAccountCommand>()
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success);
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<IAdminRecoverAccountCommand>().Received(1)
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key);
}
[Theory]
[BitAutoData]
public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountFails_ReturnsBadRequest(
Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser,
SutProvider<OrganizationUsersController> sutProvider)
{
organizationUser.OrganizationId = orgId;
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>().GetByIdAsync(orgUserId).Returns(organizationUser);
sutProvider.GetDependency<IAuthorizationService>()
.AuthorizeAsync(
Arg.Any<ClaimsPrincipal>(),
organizationUser,
Arg.Is<IEnumerable<IAuthorizationRequirement>>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement))
.Returns(AuthorizationResult.Success());
sutProvider.GetDependency<IAdminRecoverAccountCommand>()
.RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key)
.Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error message" }));
var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model);
Assert.IsType<BadRequest<ModelStateDictionary>>(result);
}
}

View File

@@ -71,6 +71,26 @@ public class SlackIntegrationControllerTests
await sutProvider.Sut.CreateAsync(string.Empty, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Slack;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == "SlackIntegration_Create"))
.Returns((string?)null);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
@@ -153,6 +173,8 @@ public class SlackIntegrationControllerTests
OrganizationIntegration wrongOrgIntegration)
{
wrongOrgIntegration.Id = integration.Id;
wrongOrgIntegration.Type = IntegrationType.Slack;
wrongOrgIntegration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
@@ -304,6 +326,22 @@ public class SlackIntegrationControllerTests
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}
[Theory, BitAutoData]
public async Task RedirectAsync_CallbackUrlReturnsEmpty_ThrowsBadRequest(
SutProvider<SlackIntegrationController> sutProvider,
Guid organizationId)
{
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == "SlackIntegration_Create"))
.Returns((string?)null);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}
[Theory, BitAutoData]
public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound(
SutProvider<SlackIntegrationController> sutProvider,

View File

@@ -60,6 +60,26 @@ public class TeamsIntegrationControllerTests
Assert.IsType<CreatedResult>(requestAction);
}
[Theory, BitAutoData]
public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest(
SutProvider<TeamsIntegrationController> sutProvider,
OrganizationIntegration integration)
{
integration.Type = IntegrationType.Teams;
integration.Configuration = null;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == "TeamsIntegration_Create"))
.Returns((string?)null);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetByIdAsync(integration.Id)
.Returns(integration);
var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency<TimeProvider>());
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString()));
}
[Theory, BitAutoData]
public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest(
SutProvider<TeamsIntegrationController> sutProvider,
@@ -315,6 +335,30 @@ public class TeamsIntegrationControllerTests
sutProvider.GetDependency<ITeamsService>().Received(1).GetRedirectUrl(Arg.Any<string>(), expectedState.ToString());
}
[Theory, BitAutoData]
public async Task RedirectAsync_CallbackUrlIsEmpty_ThrowsBadRequest(
SutProvider<TeamsIntegrationController> sutProvider,
Guid organizationId,
OrganizationIntegration integration)
{
integration.OrganizationId = organizationId;
integration.Configuration = null;
integration.Type = IntegrationType.Teams;
sutProvider.Sut.Url = Substitute.For<IUrlHelper>();
sutProvider.Sut.Url
.RouteUrl(Arg.Is<UrlRouteContext>(c => c.RouteName == "TeamsIntegration_Create"))
.Returns((string?)null);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(organizationId)
.Returns(true);
sutProvider.GetDependency<IOrganizationIntegrationRepository>()
.GetManyByOrganizationAsync(organizationId)
.Returns([integration]);
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.RedirectAsync(organizationId));
}
[Theory, BitAutoData]
public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest(
SutProvider<TeamsIntegrationController> sutProvider,

View File

@@ -1,14 +1,47 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json;
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Core.Enums;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations;
public class OrganizationIntegrationRequestModelTests
{
[Fact]
public void ToOrganizationIntegration_CreatesNewOrganizationIntegration()
{
var model = new OrganizationIntegrationRequestModel
{
Type = IntegrationType.Hec,
Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token"))
};
var organizationId = Guid.NewGuid();
var organizationIntegration = model.ToOrganizationIntegration(organizationId);
Assert.Equal(organizationIntegration.Type, model.Type);
Assert.Equal(organizationIntegration.Configuration, model.Configuration);
Assert.Equal(organizationIntegration.OrganizationId, organizationId);
}
[Theory, BitAutoData]
public void ToOrganizationIntegration_UpdatesExistingOrganizationIntegration(OrganizationIntegration integration)
{
var model = new OrganizationIntegrationRequestModel
{
Type = IntegrationType.Hec,
Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token"))
};
var organizationIntegration = model.ToOrganizationIntegration(integration);
Assert.Equal(organizationIntegration.Configuration, model.Configuration);
}
[Fact]
public void Validate_CloudBillingSync_ReturnsNotYetSupportedError()
{

View File

@@ -24,11 +24,11 @@ public class SavePolicyRequestTests
currentContext.OrganizationOwner(organizationId).Returns(true);
var testData = new Dictionary<string, object> { { "test", "value" } };
var policyType = PolicyType.TwoFactorAuthentication;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.TwoFactorAuthentication,
Enabled = true,
Data = testData
},
@@ -36,7 +36,7 @@ public class SavePolicyRequestTests
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.Equal(PolicyType.TwoFactorAuthentication, result.PolicyUpdate.Type);
@@ -54,7 +54,7 @@ public class SavePolicyRequestTests
}
[Theory, BitAutoData]
public async Task ToSavePolicyModelAsync_WithNullData_HandlesCorrectly(
public async Task ToSavePolicyModelAsync_WithEmptyData_HandlesCorrectly(
Guid organizationId,
Guid userId)
{
@@ -63,19 +63,17 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(false);
var policyType = PolicyType.SingleOrg;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.SingleOrg,
Enabled = false,
Data = null
},
Metadata = null
Enabled = false
}
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.Null(result.PolicyUpdate.Data);
@@ -95,19 +93,17 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(true);
var policyType = PolicyType.SingleOrg;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.SingleOrg,
Enabled = false,
Data = null
},
Metadata = null
Enabled = false
}
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.Null(result.PolicyUpdate.Data);
@@ -128,13 +124,12 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(true);
var policyType = PolicyType.OrganizationDataOwnership;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
Enabled = true
},
Metadata = new Dictionary<string, object>
{
@@ -143,7 +138,7 @@ public class SavePolicyRequestTests
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.IsType<OrganizationModelOwnershipPolicyModel>(result.Metadata);
@@ -152,7 +147,7 @@ public class SavePolicyRequestTests
}
[Theory, BitAutoData]
public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithNullMetadata_ReturnsEmptyMetadata(
public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithEmptyMetadata_ReturnsEmptyMetadata(
Guid organizationId,
Guid userId)
{
@@ -161,19 +156,17 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(true);
var policyType = PolicyType.OrganizationDataOwnership;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
},
Metadata = null
Enabled = true
}
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.NotNull(result);
@@ -200,12 +193,11 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(true);
var policyType = PolicyType.ResetPassword;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.ResetPassword,
Enabled = true,
Data = _complexData
},
@@ -213,7 +205,7 @@ public class SavePolicyRequestTests
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
var deserializedData = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(result.PolicyUpdate.Data);
@@ -241,13 +233,12 @@ public class SavePolicyRequestTests
currentContext.UserId.Returns(userId);
currentContext.OrganizationOwner(organizationId).Returns(true);
var policyType = PolicyType.MaximumVaultTimeout;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.MaximumVaultTimeout,
Enabled = true,
Data = null
Enabled = true
},
Metadata = new Dictionary<string, object>
{
@@ -256,7 +247,7 @@ public class SavePolicyRequestTests
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.NotNull(result);
@@ -274,20 +265,18 @@ public class SavePolicyRequestTests
currentContext.OrganizationOwner(organizationId).Returns(true);
var errorDictionary = BuildErrorDictionary();
var policyType = PolicyType.OrganizationDataOwnership;
var model = new SavePolicyRequest
{
Policy = new PolicyRequestModel
{
Type = PolicyType.OrganizationDataOwnership,
Enabled = true,
Data = null
Enabled = true
},
Metadata = errorDictionary
};
// Act
var result = await model.ToSavePolicyModelAsync(organizationId, currentContext);
var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext);
// Assert
Assert.NotNull(result);

View File

@@ -0,0 +1,150 @@
using Bit.Api.AdminConsole.Models.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Response;
public class ProfileOrganizationResponseModelTests
{
[Theory, BitAutoData]
public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization)
{
var userId = Guid.NewGuid();
var organizationUserId = Guid.NewGuid();
var providerId = Guid.NewGuid();
var organizationIdsClaimingUser = new[] { organization.Id };
var organizationDetails = new OrganizationUserOrganizationDetails
{
OrganizationId = organization.Id,
UserId = userId,
OrganizationUserId = organizationUserId,
Name = organization.Name,
Enabled = organization.Enabled,
Identifier = organization.Identifier,
PlanType = organization.PlanType,
UsePolicies = organization.UsePolicies,
UseSso = organization.UseSso,
UseKeyConnector = organization.UseKeyConnector,
UseScim = organization.UseScim,
UseGroups = organization.UseGroups,
UseDirectory = organization.UseDirectory,
UseEvents = organization.UseEvents,
UseTotp = organization.UseTotp,
Use2fa = organization.Use2fa,
UseApi = organization.UseApi,
UseResetPassword = organization.UseResetPassword,
UseSecretsManager = organization.UseSecretsManager,
UsePasswordManager = organization.UsePasswordManager,
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,
SelfHost = organization.SelfHost,
Seats = organization.Seats,
MaxCollections = organization.MaxCollections,
MaxStorageGb = organization.MaxStorageGb,
Key = "organization-key",
PublicKey = "public-key",
PrivateKey = "private-key",
LimitCollectionCreation = organization.LimitCollectionCreation,
LimitCollectionDeletion = organization.LimitCollectionDeletion,
LimitItemDeletion = organization.LimitItemDeletion,
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems,
ProviderId = providerId,
ProviderName = "Test Provider",
ProviderType = ProviderType.Msp,
SsoEnabled = true,
SsoConfig = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize(),
SsoExternalId = "external-sso-id",
Permissions = CoreHelpers.ClassToJsonData(new Core.Models.Data.Permissions { ManageUsers = true }),
ResetPasswordKey = "reset-password-key",
FamilySponsorshipFriendlyName = "Family Sponsorship",
FamilySponsorshipLastSyncDate = DateTime.UtcNow.AddDays(-1),
FamilySponsorshipToDelete = false,
FamilySponsorshipValidUntil = DateTime.UtcNow.AddYears(1),
IsAdminInitiated = true,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
AccessSecretsManager = true,
SmSeats = 5,
SmServiceAccounts = 10
};
var result = new ProfileOrganizationResponseModel(organizationDetails, organizationIdsClaimingUser);
Assert.Equal("profileOrganization", result.Object);
Assert.Equal(organization.Id, result.Id);
Assert.Equal(userId, result.UserId);
Assert.Equal(organization.Name, result.Name);
Assert.Equal(organization.Enabled, result.Enabled);
Assert.Equal(organization.Identifier, result.Identifier);
Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType);
Assert.Equal(organization.UsePolicies, result.UsePolicies);
Assert.Equal(organization.UseSso, result.UseSso);
Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector);
Assert.Equal(organization.UseScim, result.UseScim);
Assert.Equal(organization.UseGroups, result.UseGroups);
Assert.Equal(organization.UseDirectory, result.UseDirectory);
Assert.Equal(organization.UseEvents, result.UseEvents);
Assert.Equal(organization.UseTotp, result.UseTotp);
Assert.Equal(organization.Use2fa, result.Use2fa);
Assert.Equal(organization.UseApi, result.UseApi);
Assert.Equal(organization.UseResetPassword, result.UseResetPassword);
Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager);
Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager);
Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium);
Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions);
Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
Assert.Equal(organization.SelfHost, result.SelfHost);
Assert.Equal(organization.Seats, result.Seats);
Assert.Equal(organization.MaxCollections, result.MaxCollections);
Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb);
Assert.Equal(organizationDetails.Key, result.Key);
Assert.True(result.HasPublicAndPrivateKeys);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organizationDetails.ProviderId, result.ProviderId);
Assert.Equal(organizationDetails.ProviderName, result.ProviderName);
Assert.Equal(organizationDetails.ProviderType, result.ProviderType);
Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled);
Assert.True(result.KeyConnectorEnabled);
Assert.Equal("https://keyconnector.example.com", result.KeyConnectorUrl);
Assert.Equal(MemberDecryptionType.KeyConnector, result.SsoMemberDecryptionType);
Assert.True(result.SsoBound);
Assert.Equal(organizationDetails.Status, result.Status);
Assert.Equal(organizationDetails.Type, result.Type);
Assert.Equal(organizationDetails.OrganizationUserId, result.OrganizationUserId);
Assert.True(result.UserIsClaimedByOrganization);
Assert.NotNull(result.Permissions);
Assert.True(result.ResetPasswordEnrolled);
Assert.Equal(organizationDetails.AccessSecretsManager, result.AccessSecretsManager);
Assert.Equal(organizationDetails.FamilySponsorshipFriendlyName, result.FamilySponsorshipFriendlyName);
Assert.Equal(organizationDetails.FamilySponsorshipLastSyncDate, result.FamilySponsorshipLastSyncDate);
Assert.Equal(organizationDetails.FamilySponsorshipToDelete, result.FamilySponsorshipToDelete);
Assert.Equal(organizationDetails.FamilySponsorshipValidUntil, result.FamilySponsorshipValidUntil);
Assert.True(result.IsAdminInitiated);
Assert.False(result.FamilySponsorshipAvailable);
}
}

View File

@@ -0,0 +1,129 @@
using Bit.Api.AdminConsole.Models.Response;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Enums;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Models.Response;
public class ProfileProviderOrganizationResponseModelTests
{
[Theory, BitAutoData]
public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization)
{
var userId = Guid.NewGuid();
var providerId = Guid.NewGuid();
var providerUserId = Guid.NewGuid();
var organizationDetails = new ProviderUserOrganizationDetails
{
OrganizationId = organization.Id,
UserId = userId,
Name = organization.Name,
Enabled = organization.Enabled,
Identifier = organization.Identifier,
PlanType = organization.PlanType,
UsePolicies = organization.UsePolicies,
UseSso = organization.UseSso,
UseKeyConnector = organization.UseKeyConnector,
UseScim = organization.UseScim,
UseGroups = organization.UseGroups,
UseDirectory = organization.UseDirectory,
UseEvents = organization.UseEvents,
UseTotp = organization.UseTotp,
Use2fa = organization.Use2fa,
UseApi = organization.UseApi,
UseResetPassword = organization.UseResetPassword,
UseSecretsManager = organization.UseSecretsManager,
UsePasswordManager = organization.UsePasswordManager,
UsersGetPremium = organization.UsersGetPremium,
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,
SelfHost = organization.SelfHost,
Seats = organization.Seats,
MaxCollections = organization.MaxCollections,
MaxStorageGb = organization.MaxStorageGb,
Key = "provider-org-key",
PublicKey = "public-key",
PrivateKey = "private-key",
LimitCollectionCreation = organization.LimitCollectionCreation,
LimitCollectionDeletion = organization.LimitCollectionDeletion,
LimitItemDeletion = organization.LimitItemDeletion,
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems,
ProviderId = providerId,
ProviderName = "Test MSP Provider",
ProviderType = ProviderType.Msp,
SsoEnabled = true,
SsoConfig = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption
}.Serialize(),
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin,
ProviderUserId = providerUserId
};
var result = new ProfileProviderOrganizationResponseModel(organizationDetails);
Assert.Equal("profileProviderOrganization", result.Object);
Assert.Equal(organization.Id, result.Id);
Assert.Equal(userId, result.UserId);
Assert.Equal(organization.Name, result.Name);
Assert.Equal(organization.Enabled, result.Enabled);
Assert.Equal(organization.Identifier, result.Identifier);
Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType);
Assert.Equal(organization.UsePolicies, result.UsePolicies);
Assert.Equal(organization.UseSso, result.UseSso);
Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector);
Assert.Equal(organization.UseScim, result.UseScim);
Assert.Equal(organization.UseGroups, result.UseGroups);
Assert.Equal(organization.UseDirectory, result.UseDirectory);
Assert.Equal(organization.UseEvents, result.UseEvents);
Assert.Equal(organization.UseTotp, result.UseTotp);
Assert.Equal(organization.Use2fa, result.Use2fa);
Assert.Equal(organization.UseApi, result.UseApi);
Assert.Equal(organization.UseResetPassword, result.UseResetPassword);
Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager);
Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager);
Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium);
Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions);
Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
Assert.Equal(organization.SelfHost, result.SelfHost);
Assert.Equal(organization.Seats, result.Seats);
Assert.Equal(organization.MaxCollections, result.MaxCollections);
Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb);
Assert.Equal(organizationDetails.Key, result.Key);
Assert.True(result.HasPublicAndPrivateKeys);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organizationDetails.ProviderId, result.ProviderId);
Assert.Equal(organizationDetails.ProviderName, result.ProviderName);
Assert.Equal(organizationDetails.ProviderType, result.ProviderType);
Assert.Equal(OrganizationUserStatusType.Confirmed, result.Status);
Assert.Equal(OrganizationUserType.Owner, result.Type);
Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled);
Assert.False(result.KeyConnectorEnabled);
Assert.Null(result.KeyConnectorUrl);
Assert.Equal(MemberDecryptionType.TrustedDeviceEncryption, result.SsoMemberDecryptionType);
Assert.False(result.SsoBound);
Assert.NotNull(result.Permissions);
Assert.False(result.Permissions.ManageUsers);
Assert.False(result.ResetPasswordEnrolled);
Assert.False(result.AccessSecretsManager);
}
}

View File

@@ -0,0 +1,87 @@
using Bit.Api.AdminConsole.Public.Controllers;
using Bit.Api.AdminConsole.Public.Models.Request;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Context;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.AdminConsole.Public.Controllers;
[ControllerCustomize(typeof(PoliciesController))]
[SutProviderCustomize]
public class PoliciesControllerTests
{
[Theory]
[BitAutoData]
public async Task Put_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
Guid organizationId,
PolicyType policyType,
PolicyUpdateRequestModel model,
Policy policy,
SutProvider<PoliciesController> sutProvider)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.OrganizationId.Returns(organizationId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
sutProvider.GetDependency<IVNextSavePolicyCommand>()
.SaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
await sutProvider.Sut.Put(policyType, model);
// Assert
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.OrganizationId == organizationId &&
m.PolicyUpdate.Type == policyType &&
m.PolicyUpdate.Enabled == model.Enabled.GetValueOrDefault() &&
m.PerformedBy is SystemUser));
}
[Theory]
[BitAutoData]
public async Task Put_WhenPolicyValidatorsRefactorDisabled_UsesLegacySavePolicyCommand(
Guid organizationId,
PolicyType policyType,
PolicyUpdateRequestModel model,
Policy policy,
SutProvider<PoliciesController> sutProvider)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.OrganizationId.Returns(organizationId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(false);
sutProvider.GetDependency<ISavePolicyCommand>()
.SaveAsync(Arg.Any<PolicyUpdate>())
.Returns(policy);
// Act
await sutProvider.Sut.Put(policyType, model);
// Assert
await sutProvider.GetDependency<ISavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<PolicyUpdate>(p =>
p.OrganizationId == organizationId &&
p.Type == policyType &&
p.Enabled == model.Enabled));
}
}

View File

@@ -0,0 +1,800 @@
using System.Security.Claims;
using Bit.Api.Billing.Controllers;
using Bit.Core;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Queries.Interfaces;
using Bit.Core.Models.Business;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Test.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Api.Test.Billing.Controllers;
[SubscriptionInfoCustomize]
public class AccountsControllerTests : IDisposable
{
private const string TestMilestone2CouponId = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount;
private readonly IUserService _userService;
private readonly IFeatureService _featureService;
private readonly IPaymentService _paymentService;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly IUserAccountKeysQuery _userAccountKeysQuery;
private readonly GlobalSettings _globalSettings;
private readonly AccountsController _sut;
public AccountsControllerTests()
{
_userService = Substitute.For<IUserService>();
_featureService = Substitute.For<IFeatureService>();
_paymentService = Substitute.For<IPaymentService>();
_twoFactorIsEnabledQuery = Substitute.For<ITwoFactorIsEnabledQuery>();
_userAccountKeysQuery = Substitute.For<IUserAccountKeysQuery>();
_globalSettings = new GlobalSettings { SelfHosted = false };
_sut = new AccountsController(
_userService,
_twoFactorIsEnabledQuery,
_userAccountKeysQuery,
_featureService
);
}
public void Dispose()
{
_sut?.Dispose();
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WhenFeatureFlagEnabled_IncludesDiscount(
User user,
SubscriptionInfo subscriptionInfo,
UserLicense license)
{
// Arrange
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = TestMilestone2CouponId,
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe; // User has payment gateway
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Equal(20m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WhenFeatureFlagDisabled_ExcludesDiscount(
User user,
SubscriptionInfo subscriptionInfo,
UserLicense license)
{
// Arrange
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = TestMilestone2CouponId,
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(false);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe; // User has payment gateway
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount); // Should be null when feature flag is disabled
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithNonMatchingCouponId_ExcludesDiscount(
User user,
SubscriptionInfo subscriptionInfo,
UserLicense license)
{
// Arrange
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = "different-coupon-id", // Non-matching coupon ID
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe; // User has payment gateway
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount); // Should be null when coupon ID doesn't match
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WhenSelfHosted_ReturnsBasicResponse(User user)
{
// Arrange
var selfHostedSettings = new GlobalSettings { SelfHosted = true };
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
// Act
var result = await _sut.GetSubscriptionAsync(selfHostedSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount);
await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any<User>());
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WhenNoGateway_ExcludesDiscount(User user, UserLicense license)
{
// Arrange
user.Gateway = null; // No gateway configured
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_userService.GenerateLicenseAsync(user).Returns(license);
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount); // Should be null when no gateway
await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any<User>());
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithInactiveDiscount_ExcludesDiscount(
User user,
SubscriptionInfo subscriptionInfo,
UserLicense license)
{
// Arrange
subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = TestMilestone2CouponId,
Active = false, // Inactive discount
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe; // User has payment gateway
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount); // Should be null when discount is inactive
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_FullPipeline_ConvertsStripeDiscountToApiResponse(
User user,
UserLicense license)
{
// Arrange - Create a Stripe Discount object with real structure
var stripeDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 25m,
AmountOff = 1400, // 1400 cents = $14.00
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "prod_premium", "prod_families" }
}
},
End = null // Active discount
};
// Convert Stripe Discount to BillingCustomerDiscount (simulating what StripePaymentService does)
var billingDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount);
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingDiscount
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Verify full pipeline conversion
Assert.NotNull(result);
Assert.NotNull(result.CustomerDiscount);
// Verify Stripe data correctly converted to API response
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.True(result.CustomerDiscount.Active);
Assert.Equal(25m, result.CustomerDiscount.PercentOff);
// Verify cents-to-dollars conversion (1400 cents -> $14.00)
Assert.Equal(14.00m, result.CustomerDiscount.AmountOff);
// Verify AppliesTo products are preserved
Assert.NotNull(result.CustomerDiscount.AppliesTo);
Assert.Equal(2, result.CustomerDiscount.AppliesTo.Count());
Assert.Contains("prod_premium", result.CustomerDiscount.AppliesTo);
Assert.Contains("prod_families", result.CustomerDiscount.AppliesTo);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_FullPipeline_WithFeatureFlagToggle_ControlsVisibility(
User user,
UserLicense license)
{
// Arrange - Create Stripe Discount
var stripeDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 20m
},
End = null
};
var billingDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount);
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingDiscount
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act & Assert - Feature flag ENABLED
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
var resultWithFlag = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
Assert.NotNull(resultWithFlag.CustomerDiscount);
// Act & Assert - Feature flag DISABLED
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(false);
var resultWithoutFlag = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
Assert.Null(resultWithoutFlag.CustomerDiscount);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_CompletePipelineFromStripeToApiResponse(
User user,
UserLicense license)
{
// Arrange - Create a real Stripe Discount object as it would come from Stripe API
var stripeDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 30m,
AmountOff = 2000, // 2000 cents = $20.00
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "prod_premium", "prod_families", "prod_teams" }
}
},
End = null // Active discount (no end date)
};
// Step 1: Map Stripe Discount through SubscriptionInfo.BillingCustomerDiscount
// This simulates what StripePaymentService.GetSubscriptionAsync does
var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount);
// Verify the mapping worked correctly
Assert.Equal(TestMilestone2CouponId, billingCustomerDiscount.Id);
Assert.True(billingCustomerDiscount.Active);
Assert.Equal(30m, billingCustomerDiscount.PercentOff);
Assert.Equal(20.00m, billingCustomerDiscount.AmountOff); // Converted from cents
Assert.NotNull(billingCustomerDiscount.AppliesTo);
Assert.Equal(3, billingCustomerDiscount.AppliesTo.Count);
// Step 2: Create SubscriptionInfo with the mapped discount
// This simulates what StripePaymentService returns
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingCustomerDiscount
};
// Step 3: Set up controller dependencies
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act - Step 4: Call AccountsController.GetSubscriptionAsync
// This exercises the complete pipeline:
// - Retrieves subscriptionInfo from paymentService (with discount from Stripe)
// - Maps through SubscriptionInfo.BillingCustomerDiscount (already done above)
// - Filters in SubscriptionResponseModel constructor (based on feature flag, coupon ID, active status)
// - Returns via AccountsController
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Verify the complete pipeline worked end-to-end
Assert.NotNull(result);
Assert.NotNull(result.CustomerDiscount);
// Verify Stripe Discount → SubscriptionInfo.BillingCustomerDiscount mapping
// (verified above, but confirming it made it through)
// Verify SubscriptionInfo.BillingCustomerDiscount → SubscriptionResponseModel.BillingCustomerDiscount filtering
// The filter should pass because:
// - includeMilestone2Discount = true (feature flag enabled)
// - subscription.CustomerDiscount != null
// - subscription.CustomerDiscount.Id == Milestone2SubscriptionDiscount
// - subscription.CustomerDiscount.Active = true
Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id);
Assert.True(result.CustomerDiscount.Active);
Assert.Equal(30m, result.CustomerDiscount.PercentOff);
Assert.Equal(20.00m, result.CustomerDiscount.AmountOff); // Verify cents-to-dollars conversion
// Verify AppliesTo products are preserved through the entire pipeline
Assert.NotNull(result.CustomerDiscount.AppliesTo);
Assert.Equal(3, result.CustomerDiscount.AppliesTo.Count());
Assert.Contains("prod_premium", result.CustomerDiscount.AppliesTo);
Assert.Contains("prod_families", result.CustomerDiscount.AppliesTo);
Assert.Contains("prod_teams", result.CustomerDiscount.AppliesTo);
// Verify the payment service was called correctly
await _paymentService.Received(1).GetSubscriptionAsync(user);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_MultipleDiscountsInSubscription_PrefersCustomerDiscount(
User user,
UserLicense license)
{
// Arrange - Create Stripe subscription with multiple discounts
// Customer discount should be preferred over subscription discounts
var customerDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 30m,
AmountOff = null
},
End = null
};
var subscriptionDiscount1 = new Discount
{
Coupon = new Coupon
{
Id = "other-coupon-1",
PercentOff = 10m
},
End = null
};
var subscriptionDiscount2 = new Discount
{
Coupon = new Coupon
{
Id = "other-coupon-2",
PercentOff = 15m
},
End = null
};
// Map through SubscriptionInfo.BillingCustomerDiscount
var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customerDiscount);
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingCustomerDiscount
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Should use customer discount, not subscription discounts
Assert.NotNull(result);
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id);
Assert.Equal(30m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_BothPercentOffAndAmountOffPresent_HandlesEdgeCase(
User user,
UserLicense license)
{
// Arrange - Edge case: Stripe coupon with both PercentOff and AmountOff
// This tests the scenario mentioned in BillingCustomerDiscountTests.cs line 212-232
var stripeDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 25m,
AmountOff = 2000, // 2000 cents = $20.00
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "prod_premium" }
}
},
End = null
};
// Map through SubscriptionInfo.BillingCustomerDiscount
var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount);
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingCustomerDiscount
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Both values should be preserved through the pipeline
Assert.NotNull(result);
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id);
Assert.Equal(25m, result.CustomerDiscount.PercentOff);
Assert.Equal(20.00m, result.CustomerDiscount.AmountOff); // Converted from cents
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_BillingSubscriptionMapsThroughPipeline(
User user,
UserLicense license)
{
// Arrange - Create Stripe subscription with subscription details
var stripeSubscription = new Subscription
{
Id = "sub_test123",
Status = "active",
TrialStart = DateTime.UtcNow.AddDays(-30),
TrialEnd = DateTime.UtcNow.AddDays(-20),
CanceledAt = null,
CancelAtPeriodEnd = false,
CollectionMethod = "charge_automatically"
};
// Map through SubscriptionInfo.BillingSubscription
var billingSubscription = new SubscriptionInfo.BillingSubscription(stripeSubscription);
var subscriptionInfo = new SubscriptionInfo
{
Subscription = billingSubscription,
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = TestMilestone2CouponId,
Active = true,
PercentOff = 20m
}
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Verify BillingSubscription mapped through pipeline
Assert.NotNull(result);
Assert.NotNull(result.Subscription);
Assert.Equal("active", result.Subscription.Status);
Assert.Equal(14, result.Subscription.GracePeriod); // charge_automatically = 14 days
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_BillingUpcomingInvoiceMapsThroughPipeline(
User user,
UserLicense license)
{
// Arrange - Create Stripe invoice for upcoming invoice
var stripeInvoice = new Invoice
{
AmountDue = 2000, // 2000 cents = $20.00
Created = DateTime.UtcNow.AddDays(1)
};
// Map through SubscriptionInfo.BillingUpcomingInvoice
var billingUpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice);
var subscriptionInfo = new SubscriptionInfo
{
UpcomingInvoice = billingUpcomingInvoice,
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = TestMilestone2CouponId,
Active = true,
PercentOff = 20m
}
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Verify BillingUpcomingInvoice mapped through pipeline
Assert.NotNull(result);
Assert.NotNull(result.UpcomingInvoice);
Assert.Equal(20.00m, result.UpcomingInvoice.Amount); // Converted from cents
Assert.NotNull(result.UpcomingInvoice.Date);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_IntegrationTest_CompletePipelineWithAllComponents(
User user,
UserLicense license)
{
// Arrange - Complete Stripe objects for full pipeline test
var stripeDiscount = new Discount
{
Coupon = new Coupon
{
Id = TestMilestone2CouponId,
PercentOff = 20m,
AmountOff = 1000, // $10.00
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "prod_premium", "prod_families" }
}
},
End = null
};
var stripeSubscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically"
};
var stripeInvoice = new Invoice
{
AmountDue = 1500, // $15.00
Created = DateTime.UtcNow.AddDays(7)
};
// Map through SubscriptionInfo (simulating StripePaymentService)
var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount);
var billingSubscription = new SubscriptionInfo.BillingSubscription(stripeSubscription);
var billingUpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice);
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = billingCustomerDiscount,
Subscription = billingSubscription,
UpcomingInvoice = billingUpcomingInvoice
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true);
_paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo);
_userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license);
user.Gateway = GatewayType.Stripe;
// Act - Full pipeline: Stripe → SubscriptionInfo → SubscriptionResponseModel → API response
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Verify all components mapped correctly through the pipeline
Assert.NotNull(result);
// Verify discount
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id);
Assert.Equal(20m, result.CustomerDiscount.PercentOff);
Assert.Equal(10.00m, result.CustomerDiscount.AmountOff);
Assert.NotNull(result.CustomerDiscount.AppliesTo);
Assert.Equal(2, result.CustomerDiscount.AppliesTo.Count());
// Verify subscription
Assert.NotNull(result.Subscription);
Assert.Equal("active", result.Subscription.Status);
Assert.Equal(14, result.Subscription.GracePeriod);
// Verify upcoming invoice
Assert.NotNull(result.UpcomingInvoice);
Assert.Equal(15.00m, result.UpcomingInvoice.Amount);
Assert.NotNull(result.UpcomingInvoice.Date);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_SelfHosted_WithDiscountFlagEnabled_NeverIncludesDiscount(User user)
{
// Arrange - Self-hosted user with discount flag enabled (should still return null)
var selfHostedSettings = new GlobalSettings { SelfHosted = true };
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); // Flag enabled
// Act
var result = await _sut.GetSubscriptionAsync(selfHostedSettings, _paymentService);
// Assert - Should never include discount for self-hosted, even with flag enabled
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount);
await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any<User>());
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_NullGateway_WithDiscountFlagEnabled_NeverIncludesDiscount(
User user,
UserLicense license)
{
// Arrange - User with null gateway and discount flag enabled (should still return null)
user.Gateway = null; // No gateway configured
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
_sut.ControllerContext = new ControllerContext
{
HttpContext = new DefaultHttpContext { User = claimsPrincipal }
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_userService.GenerateLicenseAsync(user).Returns(license);
_featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); // Flag enabled
// Act
var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService);
// Assert - Should never include discount when no gateway, even with flag enabled
Assert.NotNull(result);
Assert.Null(result.CustomerDiscount);
await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any<User>());
}
}

View File

@@ -1,10 +1,15 @@
using System.Security.Claims;
using System.Text.Json;
using Bit.Api.AdminConsole.Controllers;
using Bit.Api.AdminConsole.Models.Request;
using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Context;
@@ -455,4 +460,98 @@ public class PoliciesControllerTests
Assert.Equal(enabledPolicy.Type, expectedPolicy.Type);
Assert.Equal(enabledPolicy.Enabled, expectedPolicy.Enabled);
}
[Theory]
[BitAutoData]
public async Task PutVNext_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
SutProvider<PoliciesController> sutProvider, Guid orgId,
SavePolicyRequest model, Policy policy, Guid userId)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(userId);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(orgId)
.Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
sutProvider.GetDependency<IVNextSavePolicyCommand>()
.SaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
var result = await sutProvider.Sut.PutVNext(orgId, policy.Type, model);
// Assert
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(
m => m.PolicyUpdate.OrganizationId == orgId &&
m.PolicyUpdate.Type == policy.Type &&
m.PolicyUpdate.Enabled == model.Policy.Enabled &&
m.PerformedBy.UserId == userId &&
m.PerformedBy.IsOrganizationOwnerOrProvider == true));
await sutProvider.GetDependency<ISavePolicyCommand>()
.DidNotReceiveWithAnyArgs()
.VNextSaveAsync(default);
Assert.NotNull(result);
Assert.Equal(policy.Id, result.Id);
Assert.Equal(policy.Type, result.Type);
}
[Theory]
[BitAutoData]
public async Task PutVNext_WhenPolicyValidatorsRefactorDisabled_UsesSavePolicyCommand(
SutProvider<PoliciesController> sutProvider, Guid orgId,
SavePolicyRequest model, Policy policy, Guid userId)
{
// Arrange
policy.Data = null;
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(userId);
sutProvider.GetDependency<ICurrentContext>()
.OrganizationOwner(orgId)
.Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(false);
sutProvider.GetDependency<ISavePolicyCommand>()
.VNextSaveAsync(Arg.Any<SavePolicyModel>())
.Returns(policy);
// Act
var result = await sutProvider.Sut.PutVNext(orgId, policy.Type, model);
// Assert
await sutProvider.GetDependency<ISavePolicyCommand>()
.Received(1)
.VNextSaveAsync(Arg.Is<SavePolicyModel>(
m => m.PolicyUpdate.OrganizationId == orgId &&
m.PolicyUpdate.Type == policy.Type &&
m.PolicyUpdate.Enabled == model.Policy.Enabled &&
m.PerformedBy.UserId == userId &&
m.PerformedBy.IsOrganizationOwnerOrProvider == true));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.DidNotReceiveWithAnyArgs()
.SaveAsync(default);
Assert.NotNull(result);
Assert.Equal(policy.Id, result.Id);
Assert.Equal(policy.Type, result.Type);
}
}

View File

@@ -1,4 +1,5 @@
using Bit.Api.Dirt.Controllers;
using Bit.Api.Dirt.Models.Response;
using Bit.Core.Context;
using Bit.Core.Dirt.Entities;
using Bit.Core.Dirt.Models.Data;
@@ -39,7 +40,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -262,7 +264,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -365,7 +368,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -597,7 +601,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -812,7 +817,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]
@@ -1050,7 +1056,8 @@ public class OrganizationReportControllerTests
// Assert
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal(expectedReport, okResult.Value);
var expectedResponse = new OrganizationReportResponseModel(expectedReport);
Assert.Equivalent(expectedResponse, okResult.Value);
}
[Theory, BitAutoData]

View File

@@ -0,0 +1,400 @@
using Bit.Api.Models.Response;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Entities;
using Bit.Core.Models.Business;
using Bit.Test.Common.AutoFixture.Attributes;
using Stripe;
using Xunit;
namespace Bit.Api.Test.Models.Response;
public class SubscriptionResponseModelTests
{
[Theory]
[BitAutoData]
public void Constructor_IncludeMilestone2DiscountTrueMatchingCouponId_ReturnsDiscount(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.True(result.CustomerDiscount.Active);
Assert.Equal(20m, result.CustomerDiscount.PercentOff);
Assert.Null(result.CustomerDiscount.AmountOff);
Assert.NotNull(result.CustomerDiscount.AppliesTo);
Assert.Single(result.CustomerDiscount.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_IncludeMilestone2DiscountTrueNonMatchingCouponId_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = "different-coupon-id", // Non-matching coupon ID
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_IncludeMilestone2DiscountFalseMatchingCouponId_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: false);
// Assert - Should be null because includeMilestone2Discount is false
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_NullCustomerDiscount_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = null
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_AmountOffDiscountMatchingCouponId_ReturnsDiscount(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
Active = true,
PercentOff = null,
AmountOff = 14.00m, // Already converted from cents in BillingCustomerDiscount
AppliesTo = new List<string>()
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Null(result.CustomerDiscount.PercentOff);
Assert.Equal(14.00m, result.CustomerDiscount.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_DefaultIncludeMilestone2DiscountParameter_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
Active = true,
PercentOff = 20m
}
};
// Act - Using default parameter (includeMilestone2Discount defaults to false)
var result = new SubscriptionResponseModel(user, subscriptionInfo, license);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_NullDiscountIdIncludeMilestone2DiscountTrue_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = null, // Null discount ID
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_MatchingCouponIdInactiveDiscount_ReturnsNull(
User user,
UserLicense license)
{
// Arrange
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID
Active = false, // Inactive discount
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "product1" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_UserOnly_SetsBasicProperties(User user)
{
// Arrange
user.Storage = 5368709120; // 5 GB in bytes
user.MaxStorageGb = (short)10;
user.PremiumExpirationDate = DateTime.UtcNow.AddMonths(12);
// Act
var result = new SubscriptionResponseModel(user);
// Assert
Assert.NotNull(result.StorageName);
Assert.Equal(5.0, result.StorageGb);
Assert.Equal((short)10, result.MaxStorageGb);
Assert.Equal(user.PremiumExpirationDate, result.Expiration);
Assert.Null(result.License);
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_UserAndLicense_IncludesLicense(User user, UserLicense license)
{
// Arrange
user.Storage = 1073741824; // 1 GB in bytes
user.MaxStorageGb = (short)5;
// Act
var result = new SubscriptionResponseModel(user, license);
// Assert
Assert.NotNull(result.License);
Assert.Equal(license, result.License);
Assert.Equal(1.0, result.StorageGb);
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_NullStorage_SetsStorageToZero(User user)
{
// Arrange
user.Storage = null;
// Act
var result = new SubscriptionResponseModel(user);
// Assert
Assert.Null(result.StorageName);
Assert.Equal(0, result.StorageGb);
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_NullLicense_ExcludesLicense(User user)
{
// Act
var result = new SubscriptionResponseModel(user, null);
// Assert
Assert.Null(result.License);
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public void Constructor_BothPercentOffAndAmountOffPresent_HandlesEdgeCase(
User user,
UserLicense license)
{
// Arrange - Edge case: Both PercentOff and AmountOff present
// This tests the scenario where Stripe coupon has both discount types
var subscriptionInfo = new SubscriptionInfo
{
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
Active = true,
PercentOff = 25m,
AmountOff = 20.00m, // Already converted from cents
AppliesTo = new List<string> { "prod_premium" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert - Both values should be preserved
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Equal(25m, result.CustomerDiscount.PercentOff);
Assert.Equal(20.00m, result.CustomerDiscount.AmountOff);
Assert.NotNull(result.CustomerDiscount.AppliesTo);
Assert.Single(result.CustomerDiscount.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_WithSubscriptionAndInvoice_MapsAllProperties(
User user,
UserLicense license)
{
// Arrange - Test with Subscription, UpcomingInvoice, and CustomerDiscount
var stripeSubscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically"
};
var stripeInvoice = new Invoice
{
AmountDue = 1500, // 1500 cents = $15.00
Created = DateTime.UtcNow.AddDays(7)
};
var subscriptionInfo = new SubscriptionInfo
{
Subscription = new SubscriptionInfo.BillingSubscription(stripeSubscription),
UpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice),
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
Active = true,
PercentOff = 20m,
AmountOff = null,
AppliesTo = new List<string> { "prod_premium" }
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert - Verify all properties are mapped correctly
Assert.NotNull(result.Subscription);
Assert.Equal("active", result.Subscription.Status);
Assert.Equal(14, result.Subscription.GracePeriod); // charge_automatically = 14 days
Assert.NotNull(result.UpcomingInvoice);
Assert.Equal(15.00m, result.UpcomingInvoice.Amount);
Assert.NotNull(result.UpcomingInvoice.Date);
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.True(result.CustomerDiscount.Active);
Assert.Equal(20m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public void Constructor_WithNullSubscriptionAndInvoice_HandlesNullsGracefully(
User user,
UserLicense license)
{
// Arrange - Test with null Subscription and UpcomingInvoice
var subscriptionInfo = new SubscriptionInfo
{
Subscription = null,
UpcomingInvoice = null,
CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
Active = true,
PercentOff = 20m
}
};
// Act
var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true);
// Assert - Null Subscription and UpcomingInvoice should be handled gracefully
Assert.Null(result.Subscription);
Assert.Null(result.UpcomingInvoice);
Assert.NotNull(result.CustomerDiscount);
}
}

View File

@@ -0,0 +1,221 @@
using Bit.Api.Models.Public.Request;
using Bit.Api.Models.Public.Response;
using Bit.Api.Utilities.DiagnosticTools;
using Bit.Core;
using Bit.Core.Models.Data;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.Utilities.DiagnosticTools;
public class EventDiagnosticLoggerTests
{
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_FeatureFlagEnabled_LogsInformation(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var request = new EventFilterRequestModel()
{
Start = DateTime.UtcNow.AddMinutes(-3),
End = DateTime.UtcNow,
ActingUserId = Guid.NewGuid(),
ItemId = Guid.NewGuid(),
};
var newestEvent = Substitute.For<IEvent>();
newestEvent.Date.Returns(DateTime.UtcNow);
var middleEvent = Substitute.For<IEvent>();
middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1));
var oldestEvent = Substitute.For<IEvent>();
oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-3));
var eventResponses = new List<EventResponseModel>
{
new (newestEvent),
new (middleEvent),
new (oldestEvent)
};
var response = new PagedListResponseModel<EventResponseModel>(eventResponses, "continuation-token");
// Act
logger.LogAggregateData(featureService, organizationId, response, request);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains($"Event count:{eventResponses.Count}") &&
o.ToString().Contains($"newest record:{newestEvent.Date:O}") &&
o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") &&
o.ToString().Contains("HasMore:True") &&
o.ToString().Contains($"Start:{request.Start:o}") &&
o.ToString().Contains($"End:{request.End:o}") &&
o.ToString().Contains($"ActingUserId:{request.ActingUserId}") &&
o.ToString().Contains($"ItemId:{request.ItemId}"))
,
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_FeatureFlagDisabled_DoesNotLog(
Guid organizationId,
EventFilterRequestModel request)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false);
PagedListResponseModel<EventResponseModel> dummy = null;
// Act
logger.LogAggregateData(featureService, organizationId, dummy, request);
// Assert
logger.DidNotReceive().Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithPublicResponse_EmptyData_LogsZeroCount(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var request = new EventFilterRequestModel()
{
Start = null,
End = null,
ActingUserId = null,
ItemId = null,
ContinuationToken = null,
};
var response = new PagedListResponseModel<EventResponseModel>(new List<EventResponseModel>(), null);
// Act
logger.LogAggregateData(featureService, organizationId, response, request);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains("Event count:0") &&
o.ToString().Contains("HasMore:False")),
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_FeatureFlagDisabled_DoesNotLog(Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false);
// Act
logger.LogAggregateData(featureService, organizationId, null, null, null, null);
// Assert
logger.DidNotReceive().Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_EmptyData_LogsZeroCount(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
Bit.Api.Models.Response.EventResponseModel[] emptyEvents = [];
// Act
logger.LogAggregateData(featureService, organizationId, emptyEvents, null, null, null);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains("Event count:0") &&
o.ToString().Contains("HasMore:False")),
null,
Arg.Any<Func<object, Exception, string>>());
}
[Theory, BitAutoData]
public void LogAggregateData_WithInternalResponse_FeatureFlagEnabled_LogsInformation(
Guid organizationId)
{
// Arrange
var logger = Substitute.For<ILogger>();
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true);
var newestEvent = Substitute.For<IEvent>();
newestEvent.Date.Returns(DateTime.UtcNow);
var middleEvent = Substitute.For<IEvent>();
middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1));
var oldestEvent = Substitute.For<IEvent>();
oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-2));
var events = new List<Bit.Api.Models.Response.EventResponseModel>
{
new (newestEvent),
new (middleEvent),
new (oldestEvent)
};
var queryStart = DateTime.UtcNow.AddMinutes(-3);
var queryEnd = DateTime.UtcNow;
const string continuationToken = "continuation-token";
// Act
logger.LogAggregateData(featureService, organizationId, events, continuationToken, queryStart, queryEnd);
// Assert
logger.Received(1).Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o =>
o.ToString().Contains(organizationId.ToString()) &&
o.ToString().Contains($"Event count:{events.Count}") &&
o.ToString().Contains($"newest record:{newestEvent.Date:O}") &&
o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") &&
o.ToString().Contains("HasMore:True") &&
o.ToString().Contains($"Start:{queryStart:o}") &&
o.ToString().Contains($"End:{queryEnd:o}"))
,
null,
Arg.Any<Func<object, Exception, string>>());
}
}

View File

@@ -285,6 +285,10 @@ public class SyncControllerTests
providerUserRepository
.GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails);
foreach (var p in providerUserOrganizationDetails)
{
p.SsoConfig = null;
}
providerUserRepository
.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed)
.Returns(providerUserOrganizationDetails);

View File

@@ -0,0 +1,391 @@
using Bit.Billing.Controllers;
using Bit.Billing.Models;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Clients;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using BitPayLight.Models.Invoice;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
using Transaction = Bit.Core.Entities.Transaction;
namespace Bit.Billing.Test.Controllers;
using static BitPayConstants;
public class BitPayControllerTests
{
private readonly GlobalSettings _globalSettings = new();
private readonly IBitPayClient _bitPayClient = Substitute.For<IBitPayClient>();
private readonly ITransactionRepository _transactionRepository = Substitute.For<ITransactionRepository>();
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IMailService _mailService = Substitute.For<IMailService>();
private readonly IPaymentService _paymentService = Substitute.For<IPaymentService>();
private readonly IPremiumUserBillingService _premiumUserBillingService =
Substitute.For<IPremiumUserBillingService>();
private const string _validWebhookKey = "valid-webhook-key";
private const string _invalidWebhookKey = "invalid-webhook-key";
public BitPayControllerTests()
{
var bitPaySettings = new GlobalSettings.BitPaySettings { WebhookKey = _validWebhookKey };
_globalSettings.BitPay = bitPaySettings;
}
private BitPayController CreateController() => new(
_globalSettings,
_bitPayClient,
_transactionRepository,
_organizationRepository,
_userRepository,
_providerRepository,
_mailService,
_paymentService,
Substitute.For<ILogger<BitPayController>>(),
_premiumUserBillingService);
[Fact]
public async Task PostIpn_InvalidKey_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var result = await controller.PostIpn(eventModel, _invalidWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid key", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NullKey_ThrowsException()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
await Assert.ThrowsAsync<ArgumentNullException>(() => controller.PostIpn(eventModel, null!));
}
[Fact]
public async Task PostIpn_EmptyKey_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var result = await controller.PostIpn(eventModel, string.Empty);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid key", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NonUsdCurrency_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(currency: "EUR");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Cannot process non-USD payments", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_NullPosData_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: null!);
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_EmptyPosData_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: "");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_PosDataWithoutAccountCredit_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: "organizationId:550e8400-e29b-41d4-a716-446655440000");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_PosDataWithoutValidId_BadRequest()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(posData: PosDataKeys.AccountCredit);
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var badRequestResult = Assert.IsType<BadRequestObjectResult>(result);
Assert.Equal("Invalid POS data", badRequestResult.Value);
}
[Fact]
public async Task PostIpn_IncompleteInvoice_Ok()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice(status: "paid");
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal("Waiting for invoice to be completed", okResult.Value);
}
[Fact]
public async Task PostIpn_ExistingTransaction_Ok()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var invoice = CreateValidInvoice();
var existingTransaction = new Transaction { GatewayId = invoice.Id };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns(existingTransaction);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
var okResult = Assert.IsType<OkObjectResult>(result);
Assert.Equal("Invoice already processed", okResult.Value);
}
[Fact]
public async Task PostIpn_ValidOrganizationTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var organizationId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}");
var organization = new Organization { Id = organizationId, BillingEmail = "billing@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
_paymentService.CreditAccountAsync(organization, Arg.Any<decimal>()).Returns(true);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.OrganizationId == organizationId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _organizationRepository.Received(1).ReplaceAsync(organization);
await _mailService.Received(1).SendAddedCreditAsync("billing@example.com", 100.00m);
}
[Fact]
public async Task PostIpn_ValidUserTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var userId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}");
var user = new User { Id = userId, Email = "user@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_userRepository.GetByIdAsync(userId).Returns(user);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.UserId == userId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _premiumUserBillingService.Received(1).Credit(user, 100.00m);
await _mailService.Received(1).SendAddedCreditAsync("user@example.com", 100.00m);
}
[Fact]
public async Task PostIpn_ValidProviderTransaction_Success()
{
var controller = CreateController();
var eventModel = CreateValidEventModel();
var providerId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}");
var provider = new Provider { Id = providerId, BillingEmail = "provider@example.com" };
_bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice);
_transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null);
_providerRepository.GetByIdAsync(providerId).Returns(Task.FromResult(provider));
_paymentService.CreditAccountAsync(provider, Arg.Any<decimal>()).Returns(true);
var result = await controller.PostIpn(eventModel, _validWebhookKey);
Assert.IsType<OkResult>(result);
await _transactionRepository.Received(1).CreateAsync(Arg.Is<Transaction>(t =>
t.ProviderId == providerId &&
t.Type == TransactionType.Credit &&
t.Gateway == GatewayType.BitPay &&
t.PaymentMethodType == PaymentMethodType.BitPay));
await _providerRepository.Received(1).ReplaceAsync(provider);
await _mailService.Received(1).SendAddedCreditAsync("provider@example.com", 100.00m);
}
[Fact]
public void GetIdsFromPosData_ValidOrganizationId_ReturnsCorrectId()
{
var controller = CreateController();
var organizationId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Equal(organizationId, result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_ValidUserId_ReturnsCorrectId()
{
var controller = CreateController();
var userId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Equal(userId, result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_ValidProviderId_ReturnsCorrectId()
{
var controller = CreateController();
var providerId = Guid.NewGuid();
var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Equal(providerId, result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_InvalidGuid_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: "organizationId:invalid-guid,{PosDataKeys.AccountCredit}");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_NullPosData_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: null!);
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
[Fact]
public void GetIdsFromPosData_EmptyPosData_ReturnsNull()
{
var controller = CreateController();
var invoice = CreateValidInvoice(posData: "");
var result = controller.GetIdsFromPosData(invoice);
Assert.Null(result.OrganizationId);
Assert.Null(result.UserId);
Assert.Null(result.ProviderId);
}
private static BitPayEventModel CreateValidEventModel(string invoiceId = "test-invoice-id")
{
return new BitPayEventModel
{
Event = new BitPayEventModel.EventModel { Code = 1005, Name = "invoice_confirmed" },
Data = new BitPayEventModel.InvoiceDataModel { Id = invoiceId }
};
}
private static Invoice CreateValidInvoice(string invoiceId = "test-invoice-id", string status = "complete",
string currency = "USD", decimal price = 100.00m,
string posData = "organizationId:550e8400-e29b-41d4-a716-446655440000,accountCredit:1")
{
return new Invoice
{
Id = invoiceId,
Status = status,
Currency = currency,
Price = (double)price,
PosData = posData,
CurrentTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
Transactions =
[
new InvoiceTransaction
{
Type = null,
Confirmations = "1",
ReceivedTime = DateTime.UtcNow.ToString("O")
}
]
};
}
}

View File

@@ -0,0 +1,234 @@
using Bit.Billing.Jobs;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Quartz;
using Xunit;
namespace Bit.Billing.Test.Jobs;
public class ProviderOrganizationDisableJobTests
{
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly ILogger<ProviderOrganizationDisableJob> _logger;
private readonly ProviderOrganizationDisableJob _sut;
public ProviderOrganizationDisableJobTests()
{
_providerOrganizationRepository = Substitute.For<IProviderOrganizationRepository>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_logger = Substitute.For<ILogger<ProviderOrganizationDisableJob>>();
_sut = new ProviderOrganizationDisableJob(
_providerOrganizationRepository,
_organizationDisableCommand,
_logger);
}
[Fact]
public async Task Execute_NoOrganizations_LogsAndReturns()
{
// Arrange
var providerId = Guid.NewGuid();
var context = CreateJobExecutionContext(providerId, DateTime.UtcNow);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns((ICollection<ProviderOrganizationOrganizationDetails>)null);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
}
[Fact]
public async Task Execute_WithOrganizations_DisablesAllOrganizations()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var org1Id = Guid.NewGuid();
var org2Id = Guid.NewGuid();
var org3Id = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = org1Id },
new() { OrganizationId = org2Id },
new() { OrganizationId = org3Id }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_WithExpirationDate_PassesDateToDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = new DateTime(2025, 12, 31, 23, 59, 59);
var orgId = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = orgId }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(orgId, expirationDate);
}
[Fact]
public async Task Execute_WithNullExpirationDate_PassesNullToDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var orgId = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = orgId }
};
var context = CreateJobExecutionContext(providerId, null);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.Received(1).DisableAsync(orgId, null);
}
[Fact]
public async Task Execute_OneOrganizationFails_ContinuesProcessingOthers()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var org1Id = Guid.NewGuid();
var org2Id = Guid.NewGuid();
var org3Id = Guid.NewGuid();
var organizations = new List<ProviderOrganizationOrganizationDetails>
{
new() { OrganizationId = org1Id },
new() { OrganizationId = org2Id },
new() { OrganizationId = org3Id }
};
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
// Make org2 fail
_organizationDisableCommand.DisableAsync(org2Id, Arg.Any<DateTime?>())
.Throws(new Exception("Database error"));
// Act
await _sut.Execute(context);
// Assert - all three should be attempted
await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any<DateTime?>());
await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_ManyOrganizations_ProcessesWithLimitedConcurrency()
{
// Arrange
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
// Create 20 organizations
var organizations = Enumerable.Range(1, 20)
.Select(_ => new ProviderOrganizationOrganizationDetails { OrganizationId = Guid.NewGuid() })
.ToList();
var context = CreateJobExecutionContext(providerId, expirationDate);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(organizations);
var concurrentCalls = 0;
var maxConcurrentCalls = 0;
var lockObj = new object();
_organizationDisableCommand.DisableAsync(Arg.Any<Guid>(), Arg.Any<DateTime?>())
.Returns(callInfo =>
{
lock (lockObj)
{
concurrentCalls++;
if (concurrentCalls > maxConcurrentCalls)
{
maxConcurrentCalls = concurrentCalls;
}
}
return Task.Delay(50).ContinueWith(_ =>
{
lock (lockObj)
{
concurrentCalls--;
}
});
});
// Act
await _sut.Execute(context);
// Assert
Assert.True(maxConcurrentCalls <= 5, $"Expected max concurrency of 5, but got {maxConcurrentCalls}");
await _organizationDisableCommand.Received(20).DisableAsync(Arg.Any<Guid>(), Arg.Any<DateTime?>());
}
[Fact]
public async Task Execute_EmptyOrganizationsList_DoesNotCallDisableCommand()
{
// Arrange
var providerId = Guid.NewGuid();
var context = CreateJobExecutionContext(providerId, DateTime.UtcNow);
_providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId)
.Returns(new List<ProviderOrganizationOrganizationDetails>());
// Act
await _sut.Execute(context);
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
}
private static IJobExecutionContext CreateJobExecutionContext(Guid providerId, DateTime? expirationDate)
{
var context = Substitute.For<IJobExecutionContext>();
var jobDataMap = new JobDataMap
{
{ "providerId", providerId.ToString() },
{ "expirationDate", expirationDate?.ToString("O") }
};
context.MergedJobDataMap.Returns(jobDataMap);
return context;
}
}

View File

@@ -1,10 +1,15 @@
using Bit.Billing.Constants;
using Bit.Billing.Jobs;
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Extensions;
using Bit.Core.Services;
using NSubstitute;
using Quartz;
using Stripe;
using Xunit;
@@ -16,6 +21,10 @@ public class SubscriptionDeletedHandlerTests
private readonly IUserService _userService;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IScheduler _scheduler;
private readonly SubscriptionDeletedHandler _sut;
public SubscriptionDeletedHandlerTests()
@@ -24,11 +33,19 @@ public class SubscriptionDeletedHandlerTests
_userService = Substitute.For<IUserService>();
_stripeEventUtilityService = Substitute.For<IStripeEventUtilityService>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_providerRepository = Substitute.For<IProviderRepository>();
_providerService = Substitute.For<IProviderService>();
_schedulerFactory = Substitute.For<ISchedulerFactory>();
_scheduler = Substitute.For<IScheduler>();
_schedulerFactory.GetScheduler().Returns(_scheduler);
_sut = new SubscriptionDeletedHandler(
_stripeEventService,
_userService,
_stripeEventUtilityService,
_organizationDisableCommand);
_organizationDisableCommand,
_providerRepository,
_providerService,
_schedulerFactory);
}
[Fact]
@@ -59,6 +76,7 @@ public class SubscriptionDeletedHandlerTests
// Assert
await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default);
await _userService.DidNotReceiveWithAnyArgs().DisablePremiumAsync(default, default);
await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default);
}
[Fact]
@@ -192,4 +210,120 @@ public class SubscriptionDeletedHandlerTests
await _organizationDisableCommand.DidNotReceiveWithAnyArgs()
.DisableAsync(default, default);
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_DisablesProviderAndQueuesJob()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var provider = new Provider
{
Id = providerId,
Enabled = true
};
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await _sut.HandleAsync(stripeEvent);
// Assert
Assert.False(provider.Enabled);
await _providerService.Received(1).UpdateAsync(provider);
await _scheduler.Received(1).ScheduleJob(
Arg.Is<IJobDetail>(j => j.JobType == typeof(ProviderOrganizationDisableJob)),
Arg.Any<ITrigger>());
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_ProviderNotFound_DoesNotThrow()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns((Provider)null);
// Act & Assert - Should not throw
await _sut.HandleAsync(stripeEvent);
// Assert
await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default);
await _scheduler.DidNotReceiveWithAnyArgs().ScheduleJob(default, default);
}
[Fact]
public async Task HandleAsync_ProviderSubscriptionCanceled_QueuesJobWithCorrectParameters()
{
// Arrange
var stripeEvent = new Event();
var providerId = Guid.NewGuid();
var expirationDate = DateTime.UtcNow.AddDays(30);
var provider = new Provider
{
Id = providerId,
Enabled = true
};
var subscription = new Subscription
{
Status = StripeSubscriptionStatus.Canceled,
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { CurrentPeriodEnd = expirationDate }
]
},
Metadata = new Dictionary<string, string> { { "providerId", providerId.ToString() } }
};
_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, providerId));
_providerRepository.GetByIdAsync(providerId).Returns(provider);
// Act
await _sut.HandleAsync(stripeEvent);
// Assert
Assert.False(provider.Enabled);
await _providerService.Received(1).UpdateAsync(provider);
await _scheduler.Received(1).ScheduleJob(
Arg.Is<IJobDetail>(j =>
j.JobType == typeof(ProviderOrganizationDisableJob) &&
j.JobDataMap.GetString("providerId") == providerId.ToString() &&
j.JobDataMap.GetString("expirationDate") == expirationDate.ToString("O")),
Arg.Is<ITrigger>(t => t.Key.Name == $"disable-trigger-{providerId}"));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using MailKit.Security;
using Microsoft.Extensions.Logging;

View File

@@ -1,4 +1,6 @@
using AutoFixture;
using System.Reflection;
using AutoFixture;
using AutoFixture.Xunit2;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
@@ -23,6 +25,7 @@ public class CurrentContextOrganizationCustomization : ICustomization
}
}
[AttributeUsage(AttributeTargets.Method)]
public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribute
{
public Guid Id { get; set; }
@@ -38,3 +41,19 @@ public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribut
AccessSecretsManager = AccessSecretsManager
};
}
public class CurrentContextOrganizationAttribute : CustomizeAttribute
{
public Guid Id { get; set; }
public OrganizationUserType Type { get; set; } = OrganizationUserType.User;
public Permissions Permissions { get; set; } = new();
public bool AccessSecretsManager { get; set; } = false;
public override ICustomization GetCustomization(ParameterInfo _) => new CurrentContextOrganizationCustomization
{
Id = Id,
Type = Type,
Permissions = Permissions,
AccessSecretsManager = AccessSecretsManager
};
}

View File

@@ -20,6 +20,20 @@ public class IntegrationTemplateContextTests
Assert.Equal(expected, sut.EventMessage);
}
[Theory, BitAutoData]
public void DateIso8601_ReturnsIso8601FormattedDate(EventMessage eventMessage)
{
var testDate = new DateTime(2025, 10, 27, 13, 30, 0, DateTimeKind.Utc);
eventMessage.Date = testDate;
var sut = new IntegrationTemplateContext(eventMessage);
var result = sut.DateIso8601;
Assert.Equal("2025-10-27T13:30:00.0000000Z", result);
// Verify it's valid ISO 8601
Assert.True(DateTime.TryParse(result, out _));
}
[Theory, BitAutoData]
public void UserName_WhenUserIsSet_ReturnsName(EventMessage eventMessage, User user)
{

View File

@@ -0,0 +1,296 @@
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AutoFixture.OrganizationUserFixtures;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.AccountRecovery;
[SutProviderCustomize]
public class AdminRecoverAccountCommandTests
{
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_Success(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
User user,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
SetupValidUser(sutProvider, user, organizationUser);
SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword);
// Act
var result = await sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key);
// Assert
Assert.True(result.Succeeded);
await AssertSuccessAsync(sutProvider, user, key, organization, organizationUser);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_OrganizationDoesNotExist_ThrowsBadRequest(
[OrganizationUser] OrganizationUser organizationUser,
string newMasterPassword,
string key,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
var orgId = Guid.NewGuid();
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgId)
.Returns((Organization)null);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(orgId, organizationUser, newMasterPassword, key));
Assert.Equal("Organization does not allow password reset.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_OrganizationDoesNotAllowResetPassword_ThrowsBadRequest(
string newMasterPassword,
string key,
Organization organization,
[OrganizationUser] OrganizationUser organizationUser,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
organization.UseResetPassword = false;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Organization does not allow password reset.", exception.Message);
}
public static IEnumerable<object[]> InvalidPolicies => new object[][]
{
[new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null]
};
[Theory]
[BitMemberAutoData(nameof(InvalidPolicies))]
public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest(
Policy resetPasswordPolicy,
string newMasterPassword,
string key,
Organization organization,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
.Returns(resetPasswordPolicy);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, new OrganizationUser { Id = Guid.NewGuid() },
newMasterPassword, key));
Assert.Equal("Organization does not have the password reset policy enabled.", exception.Message);
}
public static IEnumerable<object[]> InvalidOrganizationUsers()
{
// Make an organization so we can use its Id
var organization = new Fixture().Create<Organization>();
var nonConfirmed = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = organization.Id,
Status = OrganizationUserStatusType.Invited
};
yield return [nonConfirmed, organization];
var wrongOrganization = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = Guid.NewGuid(), // Different org
ResetPasswordKey = "test-key",
UserId = Guid.NewGuid(),
};
yield return [wrongOrganization, organization];
var nullResetPasswordKey = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = null,
UserId = Guid.NewGuid(),
};
yield return [nullResetPasswordKey, organization];
var emptyResetPasswordKey = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = "",
UserId = Guid.NewGuid(),
};
yield return [emptyResetPasswordKey, organization];
var nullUserId = new OrganizationUser
{
Status = OrganizationUserStatusType.Confirmed,
OrganizationId = organization.Id,
ResetPasswordKey = "test-key",
UserId = null,
};
yield return [nullUserId, organization];
}
[Theory]
[BitMemberAutoData(nameof(InvalidOrganizationUsers))]
public async Task RecoverAccountAsync_OrganizationUserIsInvalid_ThrowsBadRequest(
OrganizationUser organizationUser,
Organization organization,
string newMasterPassword,
string key,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Organization User not valid", exception.Message);
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_UserDoesNotExist_ThrowsNotFoundException(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(organizationUser.UserId!.Value)
.Returns((User)null);
// Act & Assert
await Assert.ThrowsAsync<NotFoundException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
}
[Theory]
[BitAutoData]
public async Task RecoverAccountAsync_UserUsesKeyConnector_ThrowsBadRequest(
string newMasterPassword,
string key,
Organization organization,
OrganizationUser organizationUser,
User user,
SutProvider<AdminRecoverAccountCommand> sutProvider)
{
// Arrange
SetupValidOrganization(sutProvider, organization);
SetupValidPolicy(sutProvider, organization);
SetupValidOrganizationUser(organizationUser, organization.Id);
user.UsesKeyConnector = true;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(organizationUser.UserId!.Value)
.Returns(user);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key));
Assert.Equal("Cannot reset password of a user with Key Connector.", exception.Message);
}
private static void SetupValidOrganization(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization)
{
organization.UseResetPassword = true;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
}
private static void SetupValidPolicy(SutProvider<AdminRecoverAccountCommand> sutProvider, Organization organization)
{
var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true };
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword)
.Returns(policy);
}
private static void SetupValidOrganizationUser(OrganizationUser organizationUser, Guid orgId)
{
organizationUser.Status = OrganizationUserStatusType.Confirmed;
organizationUser.OrganizationId = orgId;
organizationUser.ResetPasswordKey = "test-key";
organizationUser.Type = OrganizationUserType.User;
}
private static void SetupValidUser(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, OrganizationUser organizationUser)
{
user.Id = organizationUser.UserId!.Value;
user.UsesKeyConnector = false;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(user.Id)
.Returns(user);
}
private static void SetupSuccessfulPasswordUpdate(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, string newMasterPassword)
{
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(user, newMasterPassword)
.Returns(IdentityResult.Success);
}
private static async Task AssertSuccessAsync(SutProvider<AdminRecoverAccountCommand> sutProvider, User user, string key,
Organization organization, OrganizationUser organizationUser)
{
await sutProvider.GetDependency<IUserRepository>().Received(1).ReplaceAsync(
Arg.Is<User>(u =>
u.Id == user.Id &&
u.Key == key &&
u.ForcePasswordReset == true &&
u.RevisionDate == u.AccountRevisionDate &&
u.LastPasswordChangeDate == u.RevisionDate));
await sutProvider.GetDependency<IMailService>().Received(1).SendAdminResetPasswordEmailAsync(
Arg.Is(user.Email),
Arg.Is(user.Name),
Arg.Is(organization.DisplayName()));
await sutProvider.GetDependency<IEventService>().Received(1).LogOrganizationUserEventAsync(
Arg.Is(organizationUser),
Arg.Is(EventType.OrganizationUser_AdminResetPassword));
await sutProvider.GetDependency<IPushNotificationService>().Received(1).PushLogOutAsync(
Arg.Is(user.Id));
}
}

View File

@@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -191,6 +192,37 @@ public class VerifyOrganizationDomainCommandTests
x.PerformedBy.UserId == userId));
}
[Theory, BitAutoData]
public async Task UserVerifyOrganizationDomainAsync_WhenPolicyValidatorsRefactorFlagEnabled_UsesVNextSavePolicyCommand(
OrganizationDomain domain, Guid userId, SutProvider<VerifyOrganizationDomainCommand> sutProvider)
{
sutProvider.GetDependency<IOrganizationDomainRepository>()
.GetClaimedDomainsByDomainNameAsync(domain.DomainName)
.Returns([]);
sutProvider.GetDependency<IDnsResolverService>()
.ResolveAsync(domain.DomainName, domain.Txt)
.Returns(true);
sutProvider.GetDependency<ICurrentContext>()
.UserId.Returns(userId);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
_ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain);
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.SingleOrg &&
m.PolicyUpdate.OrganizationId == domain.OrganizationId &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is StandardUser &&
m.PerformedBy.UserId == userId));
}
[Theory, BitAutoData]
public async Task UserVerifyOrganizationDomainAsync_WhenDomainIsNotVerified_ThenSingleOrgPolicyShouldNotBeEnabled(
OrganizationDomain domain, SutProvider<VerifyOrganizationDomainCommand> sutProvider)

View File

@@ -97,6 +97,8 @@ public class ConfirmOrganizationUserCommandTests
[BitAutoData(PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)]
[BitAutoData(PlanType.FamiliesAnnually, OrganizationUserType.Admin)]
[BitAutoData(PlanType.FamiliesAnnually, OrganizationUserType.Owner)]
[BitAutoData(PlanType.FamiliesAnnually2025, OrganizationUserType.Admin)]
[BitAutoData(PlanType.FamiliesAnnually2025, OrganizationUserType.Owner)]
[BitAutoData(PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)]
[BitAutoData(PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)]
[BitAutoData(PlanType.TeamsAnnually, OrganizationUserType.Admin)]

View File

@@ -23,6 +23,7 @@ public class CloudICloudOrganizationSignUpCommandTests
{
[Theory]
[BitAutoData(PlanType.FamiliesAnnually)]
[BitAutoData(PlanType.FamiliesAnnually2025)]
public async Task SignUp_PM_Family_Passes(PlanType planType, OrganizationSignup signup, SutProvider<CloudOrganizationSignUpCommand> sutProvider)
{
signup.Plan = planType;
@@ -65,6 +66,7 @@ public class CloudICloudOrganizationSignUpCommandTests
[Theory]
[BitAutoData(PlanType.FamiliesAnnually)]
[BitAutoData(PlanType.FamiliesAnnually2025)]
public async Task SignUp_AssignsOwnerToDefaultCollection
(PlanType planType, OrganizationSignup signup, SutProvider<CloudOrganizationSignUpCommand> sutProvider)
{

View File

@@ -0,0 +1,628 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Test.AdminConsole.AutoFixture;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
[SutProviderCustomize]
public class AutomaticUserConfirmationPolicyEventHandlerTests
{
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_SingleOrgNotEnabled_ReturnsError(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns((Policy?)null);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("Single organization policy must be enabled", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_SingleOrgPolicyDisabled_ReturnsError(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg, false)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("Single organization policy must be enabled", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_UsersNotCompliantWithSingleOrg_ReturnsError(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
Guid nonCompliantUserId,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var orgUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Confirmed,
UserId = nonCompliantUserId,
Email = "user@example.com"
};
var otherOrgUser = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = Guid.NewGuid(),
UserId = nonCompliantUserId,
Status = OrganizationUserStatusType.Confirmed
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([orgUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([otherOrgUser]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_UserWithInvitedStatusInOtherOrg_ValidationPasses(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
Guid userId,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var orgUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Confirmed,
UserId = userId,
Email = "test@email.com"
};
var otherOrgUser = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = Guid.NewGuid(),
UserId = null, // invited users do not have a user id
Status = OrganizationUserStatusType.Invited,
Email = orgUser.Email
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([orgUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([otherOrgUser]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_ProviderUsersExist_ReturnsError(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var providerUser = new ProviderUser
{
Id = Guid.NewGuid(),
ProviderId = Guid.NewGuid(),
UserId = Guid.NewGuid(),
Status = ProviderUserStatusType.Confirmed
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([providerUser]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("Provider user type", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_AllValidationsPassed_ReturnsEmptyString(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var orgUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Confirmed,
UserId = Guid.NewGuid(),
Email = "user@example.com"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([orgUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_PolicyAlreadyEnabled_ReturnsEmptyString(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, currentPolicy);
// Assert
Assert.True(string.IsNullOrEmpty(result));
await sutProvider.GetDependency<IPolicyRepository>()
.DidNotReceive()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>());
}
[Theory, BitAutoData]
public async Task ValidateAsync_DisablingPolicy_ReturnsEmptyString(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation, false)] PolicyUpdate policyUpdate,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, currentPolicy);
// Assert
Assert.True(string.IsNullOrEmpty(result));
await sutProvider.GetDependency<IPolicyRepository>()
.DidNotReceive()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), Arg.Any<PolicyType>());
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_IncludesOwnersAndAdmins_InComplianceCheck(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
Guid nonCompliantOwnerId,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var ownerUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.Owner,
Status = OrganizationUserStatusType.Confirmed,
UserId = nonCompliantOwnerId,
Email = "owner@example.com"
};
var otherOrgUser = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = Guid.NewGuid(),
UserId = nonCompliantOwnerId,
Status = OrganizationUserStatusType.Confirmed
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([ownerUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([otherOrgUser]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_InvitedUsersExcluded_FromComplianceCheck(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var invitedUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Invited,
UserId = Guid.NewGuid(),
Email = "invited@example.com"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([invitedUser]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_RevokedUsersExcluded_FromComplianceCheck(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var revokedUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Revoked,
UserId = Guid.NewGuid(),
Email = "revoked@example.com"
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([revokedUser]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_AcceptedUsersIncluded_InComplianceCheck(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
Guid nonCompliantUserId,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var acceptedUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Accepted,
UserId = nonCompliantUserId,
Email = "accepted@example.com"
};
var otherOrgUser = new OrganizationUser
{
Id = Guid.NewGuid(),
OrganizationId = Guid.NewGuid(),
UserId = nonCompliantUserId,
Status = OrganizationUserStatusType.Confirmed
};
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([acceptedUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([otherOrgUser]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase);
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_EmptyOrganization_ReturnsEmptyString(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_WithSavePolicyModel_CallsValidateWithPolicyUpdate(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId;
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg)
.Returns(singleOrgPolicy);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task OnSaveSideEffectsAsync_EnablingPolicy_SetsUseAutomaticUserConfirmationToTrue(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
Organization organization,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
organization.Id = policyUpdate.OrganizationId;
organization.UseAutomaticUserConfirmation = false;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(policyUpdate.OrganizationId)
.Returns(organization);
// Act
await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null);
// Assert
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.UpsertAsync(Arg.Is<Organization>(o =>
o.Id == organization.Id &&
o.UseAutomaticUserConfirmation == true &&
o.RevisionDate > DateTime.MinValue));
}
[Theory, BitAutoData]
public async Task OnSaveSideEffectsAsync_DisablingPolicy_SetsUseAutomaticUserConfirmationToFalse(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation, false)] PolicyUpdate policyUpdate,
Organization organization,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
organization.Id = policyUpdate.OrganizationId;
organization.UseAutomaticUserConfirmation = true;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(policyUpdate.OrganizationId)
.Returns(organization);
// Act
await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null);
// Assert
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.UpsertAsync(Arg.Is<Organization>(o =>
o.Id == organization.Id &&
o.UseAutomaticUserConfirmation == false &&
o.RevisionDate > DateTime.MinValue));
}
[Theory, BitAutoData]
public async Task OnSaveSideEffectsAsync_OrganizationNotFound_DoesNotThrowException(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(policyUpdate.OrganizationId)
.Returns((Organization?)null);
// Act
await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null);
// Assert
await sutProvider.GetDependency<IOrganizationRepository>()
.DidNotReceive()
.UpsertAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task ExecutePreUpsertSideEffectAsync_CallsOnSaveSideEffectsAsync(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
[Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy,
Organization organization,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
organization.Id = policyUpdate.OrganizationId;
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(policyUpdate.OrganizationId)
.Returns(organization);
// Act
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, currentPolicy);
// Assert
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.UpsertAsync(Arg.Is<Organization>(o =>
o.Id == organization.Id &&
o.UseAutomaticUserConfirmation == policyUpdate.Enabled));
}
[Theory, BitAutoData]
public async Task OnSaveSideEffectsAsync_UpdatesRevisionDate(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
Organization organization,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
organization.Id = policyUpdate.OrganizationId;
var originalRevisionDate = DateTime.UtcNow.AddDays(-1);
organization.RevisionDate = originalRevisionDate;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(policyUpdate.OrganizationId)
.Returns(organization);
// Act
await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null);
// Assert
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.UpsertAsync(Arg.Is<Organization>(o =>
o.Id == organization.Id &&
o.RevisionDate > originalRevisionDate));
}
}

View File

@@ -92,7 +92,7 @@ public class FreeFamiliesForEnterprisePolicyValidatorTests
.GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId)
.Returns(organizationSponsorships);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);
@@ -120,7 +120,7 @@ public class FreeFamiliesForEnterprisePolicyValidatorTests
.GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId)
.Returns(organizationSponsorships);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -32,7 +32,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(false);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -58,7 +58,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -84,7 +84,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -110,7 +110,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -199,7 +199,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -238,7 +238,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, metadata);
var policyRequest = new SavePolicyModel(policyUpdate, metadata);
// Act
await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -286,7 +286,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(false);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -312,7 +312,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -338,7 +338,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -364,7 +364,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -404,7 +404,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var collectionRepository = Substitute.For<ICollectionRepository>();
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
@@ -436,7 +436,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)
.Returns(true);
var policyRequest = new SavePolicyModel(policyUpdate, null, metadata);
var policyRequest = new SavePolicyModel(policyUpdate, metadata);
// Act
await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState);

View File

@@ -88,7 +88,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
@@ -109,7 +109,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Trusted device encryption is on", result, StringComparison.OrdinalIgnoreCase);
@@ -129,7 +129,7 @@ public class RequireSsoPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));

View File

@@ -94,7 +94,7 @@ public class ResetPasswordPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Trusted device encryption is on and requires this policy.", result, StringComparison.OrdinalIgnoreCase);
@@ -118,7 +118,7 @@ public class ResetPasswordPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));

View File

@@ -162,7 +162,7 @@ public class SingleOrgPolicyValidatorTests
.GetByOrganizationIdAsync(policyUpdate.OrganizationId)
.Returns(ssoConfig);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase);
@@ -186,7 +186,7 @@ public class SingleOrgPolicyValidatorTests
.HasVerifiedDomainsAsync(policyUpdate.OrganizationId)
.Returns(false);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy);
Assert.True(string.IsNullOrEmpty(result));
@@ -256,7 +256,7 @@ public class SingleOrgPolicyValidatorTests
.RevokeNonCompliantOrganizationUsersAsync(Arg.Any<RevokeOrganizationUsersRequest>())
.Returns(new CommandResult());
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -169,7 +169,7 @@ public class TwoFactorAuthenticationPolicyValidatorTests
(orgUserDetailUserWithout2Fa, false),
});
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy));
@@ -228,7 +228,7 @@ public class TwoFactorAuthenticationPolicyValidatorTests
.RevokeNonCompliantOrganizationUsersAsync(Arg.Any<RevokeOrganizationUsersRequest>())
.Returns(new CommandResult());
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
// Act
await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy);

View File

@@ -0,0 +1,28 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators;
public class UriMatchDefaultPolicyValidatorTests
{
private readonly UriMatchDefaultPolicyValidator _validator = new();
[Fact]
// Test that the Type property returns the correct PolicyType for this validator
public void Type_ReturnsUriMatchDefaults()
{
Assert.Equal(PolicyType.UriMatchDefaults, _validator.Type);
}
[Fact]
// Test that the RequiredPolicies property returns exactly one policy (SingleOrg) as a prerequisite
// for enabling the UriMatchDefaults policy, ensuring proper policy dependency enforcement
public void RequiredPolicies_ReturnsSingleOrgPolicy()
{
var requiredPolicies = _validator.RequiredPolicies.ToList();
Assert.Single(requiredPolicies);
Assert.Contains(PolicyType.SingleOrg, requiredPolicies);
}
}

View File

@@ -288,7 +288,7 @@ public class SavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
@@ -332,7 +332,7 @@ public class SavePolicyCommandTests
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type)

View File

@@ -33,7 +33,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var newPolicy = new Policy
{
@@ -77,7 +77,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
currentPolicy.OrganizationId = policyUpdate.OrganizationId;
sutProvider.GetDependency<IPolicyRepository>()
@@ -117,7 +117,7 @@ public class VNextSavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IApplicationCacheService>()
.GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
@@ -137,7 +137,7 @@ public class VNextSavePolicyCommandTests
{
// Arrange
var sutProvider = SutProviderFactory();
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
sutProvider.GetDependency<IApplicationCacheService>()
.GetOrganizationAbilityAsync(policyUpdate.OrganizationId)
@@ -167,7 +167,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -202,7 +202,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -237,7 +237,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var requireSsoPolicy = new Policy
{
@@ -271,7 +271,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -302,7 +302,7 @@ public class VNextSavePolicyCommandTests
new FakeVaultTimeoutDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -331,7 +331,7 @@ public class VNextSavePolicyCommandTests
new FakeSingleOrgDependencyEvent()
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
ArrangeOrganization(sutProvider, policyUpdate);
sutProvider.GetDependency<IPolicyRepository>()
@@ -356,7 +356,7 @@ public class VNextSavePolicyCommandTests
fakePolicyValidationEvent
]);
var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel());
var savePolicyModel = new SavePolicyModel(policyUpdate);
var singleOrgPolicy = new Policy
{

View File

@@ -38,6 +38,20 @@ public class EventIntegrationEventWriteServiceTests
organizationId: Arg.Is<string>(orgId => eventMessage.OrganizationId.ToString().Equals(orgId)));
}
[Fact]
public async Task CreateManyAsync_EmptyList_DoesNothing()
{
await Subject.CreateManyAsync([]);
await _eventIntegrationPublisher.DidNotReceiveWithAnyArgs().PublishEventAsync(Arg.Any<string>(), Arg.Any<string>());
}
[Fact]
public async Task DisposeAsync_DisposesEventIntegrationPublisher()
{
await Subject.DisposeAsync();
await _eventIntegrationPublisher.Received(1).DisposeAsync();
}
private static bool AssertJsonStringsMatch(EventMessage expected, string body)
{
var actual = JsonSerializer.Deserialize<EventMessage>(body);

View File

@@ -120,6 +120,16 @@ public class EventIntegrationHandlerTests
Assert.Empty(_eventIntegrationPublisher.ReceivedCalls());
}
[Theory, BitAutoData]
public async Task HandleEventAsync_NoOrganizationId_DoesNothing(EventMessage eventMessage)
{
var sutProvider = GetSutProvider(OneConfiguration(_templateBase));
eventMessage.OrganizationId = null;
await sutProvider.Sut.HandleEventAsync(eventMessage);
Assert.Empty(_eventIntegrationPublisher.ReceivedCalls());
}
[Theory, BitAutoData]
public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage)
{

View File

@@ -1,65 +0,0 @@
using Bit.Core.Models.Data;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class EventRouteServiceTests
{
private readonly IEventWriteService _broadcastEventWriteService = Substitute.For<IEventWriteService>();
private readonly IEventWriteService _storageEventWriteService = Substitute.For<IEventWriteService>();
private readonly IFeatureService _featureService = Substitute.For<IFeatureService>();
private readonly EventRouteService Subject;
public EventRouteServiceTests()
{
Subject = new EventRouteService(_broadcastEventWriteService, _storageEventWriteService, _featureService);
}
[Theory, BitAutoData]
public async Task CreateAsync_FlagDisabled_EventSentToStorageService(EventMessage eventMessage)
{
_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(false);
await Subject.CreateAsync(eventMessage);
await _broadcastEventWriteService.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any<EventMessage>());
await _storageEventWriteService.Received(1).CreateAsync(eventMessage);
}
[Theory, BitAutoData]
public async Task CreateAsync_FlagEnabled_EventSentToBroadcastService(EventMessage eventMessage)
{
_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(true);
await Subject.CreateAsync(eventMessage);
await _broadcastEventWriteService.Received(1).CreateAsync(eventMessage);
await _storageEventWriteService.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any<EventMessage>());
}
[Theory, BitAutoData]
public async Task CreateManyAsync_FlagDisabled_EventsSentToStorageService(IEnumerable<EventMessage> eventMessages)
{
_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(false);
await Subject.CreateManyAsync(eventMessages);
await _broadcastEventWriteService.DidNotReceiveWithAnyArgs().CreateManyAsync(Arg.Any<IEnumerable<EventMessage>>());
await _storageEventWriteService.Received(1).CreateManyAsync(eventMessages);
}
[Theory, BitAutoData]
public async Task CreateManyAsync_FlagEnabled_EventsSentToBroadcastService(IEnumerable<EventMessage> eventMessages)
{
_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(true);
await Subject.CreateManyAsync(eventMessages);
await _broadcastEventWriteService.Received(1).CreateManyAsync(eventMessages);
await _storageEventWriteService.DidNotReceiveWithAnyArgs().CreateManyAsync(Arg.Any<IEnumerable<EventMessage>>());
}
}

View File

@@ -42,6 +42,35 @@ public class IntegrationFilterServiceTests
Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage));
}
[Theory, BitAutoData]
public void EvaluateFilterGroup_EqualsUserIdString_Matches(EventMessage eventMessage)
{
var userId = Guid.NewGuid();
eventMessage.UserId = userId;
var group = new IntegrationFilterGroup
{
AndOperator = true,
Rules =
[
new()
{
Property = "UserId",
Operation = IntegrationFilterOperation.Equals,
Value = userId.ToString()
}
]
};
var result = _service.EvaluateFilterGroup(group, eventMessage);
Assert.True(result);
var jsonGroup = JsonSerializer.Serialize(group);
var roundtrippedGroup = JsonSerializer.Deserialize<IntegrationFilterGroup>(jsonGroup);
Assert.NotNull(roundtrippedGroup);
Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage));
}
[Theory, BitAutoData]
public void EvaluateFilterGroup_EqualsUserId_DoesNotMatch(EventMessage eventMessage)
{
@@ -281,6 +310,45 @@ public class IntegrationFilterServiceTests
Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage));
}
[Theory, BitAutoData]
public void EvaluateFilterGroup_NestedGroups_AnyMatch(EventMessage eventMessage)
{
var id = Guid.NewGuid();
var collectionId = Guid.NewGuid();
eventMessage.UserId = id;
eventMessage.CollectionId = collectionId;
var nestedGroup = new IntegrationFilterGroup
{
AndOperator = false,
Rules =
[
new() { Property = "UserId", Operation = IntegrationFilterOperation.Equals, Value = id },
new()
{
Property = "CollectionId",
Operation = IntegrationFilterOperation.In,
Value = new Guid?[] { Guid.NewGuid() }
}
]
};
var topGroup = new IntegrationFilterGroup
{
AndOperator = false,
Groups = [nestedGroup]
};
var result = _service.EvaluateFilterGroup(topGroup, eventMessage);
Assert.True(result);
var jsonGroup = JsonSerializer.Serialize(topGroup);
var roundtrippedGroup = JsonSerializer.Deserialize<IntegrationFilterGroup>(jsonGroup);
Assert.NotNull(roundtrippedGroup);
Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage));
}
[Theory, BitAutoData]
public void EvaluateFilterGroup_UnknownProperty_ReturnsFalse(EventMessage eventMessage)
{

View File

@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Models.Data.EventIntegrations;
using Bit.Core.Models.Slack;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
@@ -28,6 +29,9 @@ public class SlackIntegrationHandlerTests
var sutProvider = GetSutProvider();
message.Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token);
_slackService.SendSlackMessageByChannelIdAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
.Returns(new SlackSendMessageResponse() { Ok = true, Channel = _channelId });
var result = await sutProvider.Sut.HandleAsync(message);
Assert.True(result.Success);
@@ -39,4 +43,97 @@ public class SlackIntegrationHandlerTests
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId))
);
}
[Theory]
[InlineData("service_unavailable")]
[InlineData("ratelimited")]
[InlineData("rate_limited")]
[InlineData("internal_error")]
[InlineData("message_limit_exceeded")]
public async Task HandleAsync_FailedRetryableRequest_ReturnsFailureWithRetryable(string error)
{
var sutProvider = GetSutProvider();
var message = new IntegrationMessage<SlackIntegrationConfigurationDetails>()
{
Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token),
MessageId = "MessageId",
RenderedTemplate = "Test Message"
};
_slackService.SendSlackMessageByChannelIdAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
.Returns(new SlackSendMessageResponse() { Ok = false, Channel = _channelId, Error = error });
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.True(result.Retryable);
Assert.NotNull(result.FailureReason);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ISlackService>().Received(1).SendSlackMessageByChannelIdAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_token)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId))
);
}
[Theory]
[InlineData("access_denied")]
[InlineData("channel_not_found")]
[InlineData("token_expired")]
[InlineData("token_revoked")]
public async Task HandleAsync_FailedNonRetryableRequest_ReturnsNonRetryableFailure(string error)
{
var sutProvider = GetSutProvider();
var message = new IntegrationMessage<SlackIntegrationConfigurationDetails>()
{
Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token),
MessageId = "MessageId",
RenderedTemplate = "Test Message"
};
_slackService.SendSlackMessageByChannelIdAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
.Returns(new SlackSendMessageResponse() { Ok = false, Channel = _channelId, Error = error });
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.False(result.Retryable);
Assert.NotNull(result.FailureReason);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ISlackService>().Received(1).SendSlackMessageByChannelIdAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_token)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId))
);
}
[Fact]
public async Task HandleAsync_NullResponse_ReturnsNonRetryableFailure()
{
var sutProvider = GetSutProvider();
var message = new IntegrationMessage<SlackIntegrationConfigurationDetails>()
{
Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token),
MessageId = "MessageId",
RenderedTemplate = "Test Message"
};
_slackService.SendSlackMessageByChannelIdAsync(Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
.Returns((SlackSendMessageResponse?)null);
var result = await sutProvider.Sut.HandleAsync(message);
Assert.False(result.Success);
Assert.False(result.Retryable);
Assert.Equal("Slack response was null", result.FailureReason);
Assert.Equal(result.Message, message);
await sutProvider.GetDependency<ISlackService>().Received(1).SendSlackMessageByChannelIdAsync(
Arg.Is(AssertHelper.AssertPropertyEqual(_token)),
Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)),
Arg.Is(AssertHelper.AssertPropertyEqual(_channelId))
);
}
}

View File

@@ -146,6 +146,27 @@ public class SlackServiceTests
Assert.Empty(result);
}
[Fact]
public async Task GetChannelIdAsync_NoChannelFound_ReturnsEmptyResult()
{
var emptyResponse = JsonSerializer.Serialize(
new
{
ok = true,
channels = Array.Empty<string>(),
response_metadata = new { next_cursor = "" }
});
_handler.When(HttpMethod.Get)
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent(emptyResponse));
var sutProvider = GetSutProvider();
var result = await sutProvider.Sut.GetChannelIdAsync(_token, "general");
Assert.Empty(result);
}
[Fact]
public async Task GetChannelIdAsync_ReturnsCorrectChannelId()
{
@@ -235,6 +256,32 @@ public class SlackServiceTests
Assert.Equal(string.Empty, result);
}
[Fact]
public async Task GetDmChannelByEmailAsync_ApiErrorUnparsableDmResponse_ReturnsEmptyString()
{
var sutProvider = GetSutProvider();
var email = "user@example.com";
var userId = "U12345";
var userResponse = new
{
ok = true,
user = new { id = userId }
};
_handler.When($"https://slack.com/api/users.lookupByEmail?email={email}")
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent(JsonSerializer.Serialize(userResponse)));
_handler.When("https://slack.com/api/conversations.open")
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent("NOT JSON"));
var result = await sutProvider.Sut.GetDmChannelByEmailAsync(_token, email);
Assert.Equal(string.Empty, result);
}
[Fact]
public async Task GetDmChannelByEmailAsync_ApiErrorUserResponse_ReturnsEmptyString()
{
@@ -244,7 +291,7 @@ public class SlackServiceTests
var userResponse = new
{
ok = false,
error = "An error occured"
error = "An error occurred"
};
_handler.When($"https://slack.com/api/users.lookupByEmail?email={email}")
@@ -256,6 +303,21 @@ public class SlackServiceTests
Assert.Equal(string.Empty, result);
}
[Fact]
public async Task GetDmChannelByEmailAsync_ApiErrorUnparsableUserResponse_ReturnsEmptyString()
{
var sutProvider = GetSutProvider();
var email = "user@example.com";
_handler.When($"https://slack.com/api/users.lookupByEmail?email={email}")
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent("Not JSON"));
var result = await sutProvider.Sut.GetDmChannelByEmailAsync(_token, email);
Assert.Equal(string.Empty, result);
}
[Fact]
public void GetRedirectUrl_ReturnsCorrectUrl()
{
@@ -341,18 +403,29 @@ public class SlackServiceTests
}
[Fact]
public async Task SendSlackMessageByChannelId_Sends_Correct_Message()
public async Task SendSlackMessageByChannelId_Success_ReturnsSuccessfulResponse()
{
var sutProvider = GetSutProvider();
var channelId = "C12345";
var message = "Hello, Slack!";
var jsonResponse = JsonSerializer.Serialize(new
{
ok = true,
channel = channelId,
});
_handler.When(HttpMethod.Post)
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent(string.Empty));
.WithContent(new StringContent(jsonResponse));
await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId);
var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId);
// Response was parsed correctly
Assert.NotNull(result);
Assert.True(result.Ok);
// Request was sent correctly
Assert.Single(_handler.CapturedRequests);
var request = _handler.CapturedRequests[0];
Assert.NotNull(request);
@@ -365,4 +438,62 @@ public class SlackServiceTests
Assert.Equal(message, json.RootElement.GetProperty("text").GetString() ?? string.Empty);
Assert.Equal(channelId, json.RootElement.GetProperty("channel").GetString() ?? string.Empty);
}
[Fact]
public async Task SendSlackMessageByChannelId_Failure_ReturnsErrorResponse()
{
var sutProvider = GetSutProvider();
var channelId = "C12345";
var message = "Hello, Slack!";
var jsonResponse = JsonSerializer.Serialize(new
{
ok = false,
channel = channelId,
error = "error"
});
_handler.When(HttpMethod.Post)
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent(jsonResponse));
var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId);
// Response was parsed correctly
Assert.NotNull(result);
Assert.False(result.Ok);
Assert.NotNull(result.Error);
}
[Fact]
public async Task SendSlackMessageByChannelIdAsync_InvalidJson_ReturnsNull()
{
var sutProvider = GetSutProvider();
var channelId = "C12345";
var message = "Hello world!";
_handler.When(HttpMethod.Post)
.RespondWith(HttpStatusCode.OK)
.WithContent(new StringContent("Not JSON"));
var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId);
Assert.Null(result);
}
[Fact]
public async Task SendSlackMessageByChannelIdAsync_HttpServerError_ReturnsNull()
{
var sutProvider = GetSutProvider();
var channelId = "C12345";
var message = "Hello world!";
_handler.When(HttpMethod.Post)
.RespondWith(HttpStatusCode.InternalServerError)
.WithContent(new StringContent(string.Empty));
var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId);
Assert.Null(result);
}
}

View File

@@ -0,0 +1,59 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.Utilities;
using Bit.Core.Exceptions;
using Xunit;
namespace Bit.Core.Test.AdminConsole.Utilities;
public class PolicyDataValidatorTests
{
[Fact]
public void ValidateAndSerialize_NullData_ReturnsNull()
{
var result = PolicyDataValidator.ValidateAndSerialize(null, PolicyType.MasterPassword);
Assert.Null(result);
}
[Fact]
public void ValidateAndSerialize_ValidData_ReturnsSerializedJson()
{
var data = new Dictionary<string, object> { { "minLength", 12 } };
var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword);
Assert.NotNull(result);
Assert.Contains("\"minLength\":12", result);
}
[Fact]
public void ValidateAndSerialize_InvalidDataType_ThrowsBadRequestException()
{
var data = new Dictionary<string, object> { { "minLength", "not a number" } };
var exception = Assert.Throws<BadRequestException>(() =>
PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword));
Assert.Contains("Invalid data for MasterPassword policy", exception.Message);
Assert.Contains("minLength", exception.Message);
}
[Fact]
public void ValidateAndDeserializeMetadata_NullMetadata_ReturnsEmptyMetadataModel()
{
var result = PolicyDataValidator.ValidateAndDeserializeMetadata(null, PolicyType.SingleOrg);
Assert.IsType<EmptyMetadataModel>(result);
}
[Fact]
public void ValidateAndDeserializeMetadata_ValidMetadata_ReturnsModel()
{
var metadata = new Dictionary<string, object> { { "defaultUserCollectionName", "collection name" } };
var result = PolicyDataValidator.ValidateAndDeserializeMetadata(metadata, PolicyType.OrganizationDataOwnership);
Assert.IsType<OrganizationModelOwnershipPolicyModel>(result);
}
}

View File

@@ -1,8 +1,10 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
@@ -12,6 +14,7 @@ using Bit.Core.Auth.Services;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
@@ -364,4 +367,54 @@ public class SsoConfigServiceTests
await sutProvider.GetDependency<ISsoConfigRepository>().ReceivedWithAnyArgs()
.UpsertAsync(default);
}
[Theory, BitAutoData]
public async Task SaveAsync_Tde_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand(
SutProvider<SsoConfigService> sutProvider, Organization organization)
{
var ssoConfig = new SsoConfig
{
Id = default,
Data = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption,
}.Serialize(),
Enabled = true,
OrganizationId = organization.Id,
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)
.Returns(true);
await sutProvider.Sut.SaveAsync(ssoConfig, organization);
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.SingleOrg &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.ResetPassword &&
m.PolicyUpdate.GetDataModel<ResetPasswordDataModel>().AutoEnrollEnabled &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<IVNextSavePolicyCommand>()
.Received(1)
.SaveAsync(Arg.Is<SavePolicyModel>(m =>
m.PolicyUpdate.Type == PolicyType.RequireSso &&
m.PolicyUpdate.OrganizationId == organization.Id &&
m.PolicyUpdate.Enabled &&
m.PerformedBy is SystemUser));
await sutProvider.GetDependency<ISsoConfigRepository>().ReceivedWithAnyArgs()
.UpsertAsync(default);
}
}

View File

@@ -7,6 +7,7 @@ using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.UserFeatures.Registration.Implementations;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
@@ -80,6 +81,120 @@ public class RegisterUserCommandTests
.SendWelcomeEmailAsync(Arg.Any<User>());
}
// -----------------------------------------------------------------------------------------------
// RegisterSSOAutoProvisionedUserAsync tests
// -----------------------------------------------------------------------------------------------
[Theory, BitAutoData]
public async Task RegisterSSOAutoProvisionedUserAsync_Success(
User user,
Organization organization,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
user.Id = Guid.NewGuid();
organization.Id = Guid.NewGuid();
organization.Name = "Test Organization";
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
// Act
var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
Assert.True(result.Succeeded);
await sutProvider.GetDependency<IUserService>()
.Received(1)
.CreateUserAsync(user);
}
[Theory, BitAutoData]
public async Task RegisterSSOAutoProvisionedUserAsync_UserRegistrationFails_ReturnsFailedResult(
User user,
Organization organization,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
var expectedError = new IdentityError();
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Failed(expectedError));
// Act
var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
Assert.False(result.Succeeded);
Assert.Contains(expectedError, result.Errors);
await sutProvider.GetDependency<IMailService>()
.DidNotReceive()
.SendOrganizationUserWelcomeEmailAsync(Arg.Any<User>(), Arg.Any<string>());
}
[Theory]
[BitAutoData(PlanType.EnterpriseAnnually)]
[BitAutoData(PlanType.EnterpriseMonthly)]
[BitAutoData(PlanType.TeamsAnnually)]
public async Task RegisterSSOAutoProvisionedUserAsync_EnterpriseOrg_SendsOrganizationWelcomeEmail(
PlanType planType,
User user,
Organization organization,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
organization.PlanType = planType;
organization.Name = "Enterprise Org";
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((OrganizationUser)null);
// Act
await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendOrganizationUserWelcomeEmailAsync(user, organization.Name);
}
[Theory, BitAutoData]
public async Task RegisterSSOAutoProvisionedUserAsync_FeatureFlagDisabled_SendsLegacyWelcomeEmail(
User user,
Organization organization,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(false);
// Act
await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendWelcomeEmailAsync(user);
}
// -----------------------------------------------------------------------------------------------
// RegisterUserWithOrganizationInviteToken tests
// -----------------------------------------------------------------------------------------------
@@ -646,5 +761,186 @@ public class RegisterUserCommandTests
Assert.Equal("Open registration has been disabled by the system administrator.", result.Message);
}
// -----------------------------------------------------------------------------------------------
// SendWelcomeEmail tests
// -----------------------------------------------------------------------------------------------
[Theory]
[BitAutoData(PlanType.FamiliesAnnually)]
[BitAutoData(PlanType.FamiliesAnnually2019)]
[BitAutoData(PlanType.Free)]
public async Task SendWelcomeEmail_FamilyOrg_SendsFamilyWelcomeEmail(
PlanType planType,
User user,
Organization organization,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
organization.PlanType = planType;
organization.Name = "Family Org";
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((OrganizationUser)null);
// Act
await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, organization.Name);
}
[Theory]
[BitAutoData]
public async Task SendWelcomeEmail_OrganizationNull_SendsIndividualWelcomeEmail(
User user,
OrganizationUser orgUser,
string orgInviteToken,
string masterPasswordHash,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
user.ReferenceData = null;
orgUser.Email = user.Email;
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user, masterPasswordHash)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByIdAsync(orgUser.Id)
.Returns(orgUser);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns((Policy)null);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgUser.OrganizationId)
.Returns((Organization)null);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
var orgInviteTokenable = new OrgUserInviteTokenable(orgUser);
sutProvider.GetDependency<IDataProtectorTokenFactory<OrgUserInviteTokenable>>()
.TryUnprotect(orgInviteToken, out Arg.Any<OrgUserInviteTokenable>())
.Returns(callInfo =>
{
callInfo[1] = orgInviteTokenable;
return true;
});
// Act
var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id);
// Assert
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendIndividualUserWelcomeEmailAsync(user);
}
[Theory]
[BitAutoData]
public async Task SendWelcomeEmail_OrganizationDisplayNameNull_SendsIndividualWelcomeEmail(
User user,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
Organization organization = new Organization
{
Name = null
};
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByIdAsync(Arg.Any<Guid>())
.Returns((OrganizationUser)null);
// Act
await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization);
// Assert
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendIndividualUserWelcomeEmailAsync(user);
}
[Theory]
[BitAutoData]
public async Task GetOrganizationWelcomeEmailDetailsAsync_HappyPath_ReturnsOrganizationWelcomeEmailDetails(
Organization organization,
User user,
OrganizationUser orgUser,
string masterPasswordHash,
string orgInviteToken,
SutProvider<RegisterUserCommand> sutProvider)
{
// Arrange
user.ReferenceData = null;
orgUser.Email = user.Email;
organization.PlanType = PlanType.EnterpriseAnnually;
sutProvider.GetDependency<IUserService>()
.CreateUserAsync(user, masterPasswordHash)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByIdAsync(orgUser.Id)
.Returns(orgUser);
sutProvider.GetDependency<IPolicyRepository>()
.GetByOrganizationIdTypeAsync(Arg.Any<Guid>(), PolicyType.TwoFactorAuthentication)
.Returns((Policy)null);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(orgUser.OrganizationId)
.Returns(organization);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)
.Returns(true);
var orgInviteTokenable = new OrgUserInviteTokenable(orgUser);
sutProvider.GetDependency<IDataProtectorTokenFactory<OrgUserInviteTokenable>>()
.TryUnprotect(orgInviteToken, out Arg.Any<OrgUserInviteTokenable>())
.Returns(callInfo =>
{
callInfo[1] = orgInviteTokenable;
return true;
});
// Act
var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id);
// Assert
Assert.True(result.Succeeded);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.GetByIdAsync(orgUser.OrganizationId);
await sutProvider.GetDependency<IMailService>()
.Received(1)
.SendOrganizationUserWelcomeEmailAsync(user, organization.DisplayName());
}
}

View File

@@ -21,15 +21,6 @@ namespace Bit.Core.Test.Billing.Organizations.Queries;
[SutProviderCustomize]
public class GetOrganizationMetadataQueryTests
{
[Theory, BitAutoData]
public async Task Run_NullOrganization_ReturnsNull(
SutProvider<GetOrganizationMetadataQuery> sutProvider)
{
var result = await sutProvider.Sut.Run(null);
Assert.Null(result);
}
[Theory, BitAutoData]
public async Task Run_SelfHosted_ReturnsDefault(
Organization organization,
@@ -74,8 +65,7 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 5, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.ReturnsNull();
var result = await sutProvider.Sut.Run(organization);
@@ -100,12 +90,12 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 7, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
sutProvider.GetDependency<ISubscriberService>()
.GetSubscription(organization)
.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to")))
.ReturnsNull();
var result = await sutProvider.Sut.Run(organization);
@@ -124,23 +114,24 @@ public class GetOrganizationMetadataQueryTests
organization.PlanType = PlanType.EnterpriseAnnually;
var productId = "product_123";
var customer = new Customer
{
Discount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = [productId]
}
}
}
};
var customer = new Customer();
var subscription = new Subscription
{
Discounts =
[
new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = [productId]
}
}
}
],
Items = new StripeList<SubscriptionItem>
{
Data =
@@ -162,12 +153,12 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 15, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
sutProvider.GetDependency<ISubscriberService>()
.GetSubscription(organization)
.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to")))
.Returns(subscription);
sutProvider.GetDependency<IPricingClient>()
@@ -189,13 +180,11 @@ public class GetOrganizationMetadataQueryTests
organization.GatewaySubscriptionId = "sub_123";
organization.PlanType = PlanType.TeamsAnnually;
var customer = new Customer
{
Discount = null
};
var customer = new Customer();
var subscription = new Subscription
{
Discounts = null,
Items = new StripeList<SubscriptionItem>
{
Data =
@@ -217,12 +206,12 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 20, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
sutProvider.GetDependency<ISubscriberService>()
.GetSubscription(organization)
.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to")))
.Returns(subscription);
sutProvider.GetDependency<IPricingClient>()
@@ -244,23 +233,24 @@ public class GetOrganizationMetadataQueryTests
organization.GatewaySubscriptionId = "sub_123";
organization.PlanType = PlanType.EnterpriseAnnually;
var customer = new Customer
{
Discount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = ["different_product_id"]
}
}
}
};
var customer = new Customer();
var subscription = new Subscription
{
Discounts =
[
new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = ["different_product_id"]
}
}
}
],
Items = new StripeList<SubscriptionItem>
{
Data =
@@ -282,12 +272,12 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 12, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
sutProvider.GetDependency<ISubscriberService>()
.GetSubscription(organization)
.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to")))
.Returns(subscription);
sutProvider.GetDependency<IPricingClient>()
@@ -310,23 +300,24 @@ public class GetOrganizationMetadataQueryTests
organization.PlanType = PlanType.FamiliesAnnually;
var productId = "product_123";
var customer = new Customer
{
Discount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = [productId]
}
}
}
};
var customer = new Customer();
var subscription = new Subscription
{
Discounts =
[
new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = [productId]
}
}
}
],
Items = new StripeList<SubscriptionItem>
{
Data =
@@ -348,12 +339,12 @@ public class GetOrganizationMetadataQueryTests
.Returns(new OrganizationSeatCounts { Users = 8, Sponsored = 0 });
sutProvider.GetDependency<ISubscriberService>()
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
sutProvider.GetDependency<ISubscriberService>()
.GetSubscription(organization)
.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to")))
.Returns(subscription);
sutProvider.GetDependency<IPricingClient>()

View File

@@ -1,5 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Clients;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Entities;
@@ -11,12 +12,18 @@ using Invoice = BitPayLight.Models.Invoice.Invoice;
namespace Bit.Core.Test.Billing.Payment.Commands;
using static BitPayConstants;
public class CreateBitPayInvoiceForCreditCommandTests
{
private readonly IBitPayClient _bitPayClient = Substitute.For<IBitPayClient>();
private readonly GlobalSettings _globalSettings = new()
{
BitPay = new GlobalSettings.BitPaySettings { NotificationUrl = "https://example.com/bitpay/notification" }
BitPay = new GlobalSettings.BitPaySettings
{
NotificationUrl = "https://example.com/bitpay/notification",
WebhookKey = "test-webhook-key"
}
};
private const string _redirectUrl = "https://bitwarden.com/redirect";
private readonly CreateBitPayInvoiceForCreditCommand _command;
@@ -37,8 +44,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == user.Email &&
options.Buyer.Name == user.Email &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"userId:{user.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"userId:{user.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
@@ -58,8 +65,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == organization.BillingEmail &&
options.Buyer.Name == organization.Name &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"organizationId:{organization.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"organizationId:{organization.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
@@ -79,8 +86,8 @@ public class CreateBitPayInvoiceForCreditCommandTests
_bitPayClient.CreateInvoice(Arg.Is<Invoice>(options =>
options.Buyer.Email == provider.BillingEmail &&
options.Buyer.Name == provider.Name &&
options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
options.PosData == $"providerId:{provider.Id},accountCredit:1" &&
options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" &&
options.PosData == $"providerId:{provider.Id},{PosDataKeys.AccountCredit}" &&
// ReSharper disable once CompareOfFloatsByEqualityOperator
options.Price == Convert.ToDouble(10M) &&
options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });

View File

@@ -2,7 +2,9 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
@@ -34,6 +36,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPushNotificationService _pushNotificationService = Substitute.For<IPushNotificationService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IHasPaymentMethodQuery _hasPaymentMethodQuery = Substitute.For<IHasPaymentMethodQuery>();
private readonly IUpdatePaymentMethodCommand _updatePaymentMethodCommand = Substitute.For<IUpdatePaymentMethodCommand>();
private readonly CreatePremiumCloudHostedSubscriptionCommand _command;
public CreatePremiumCloudHostedSubscriptionCommandTests()
@@ -62,7 +66,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
_userService,
_pushNotificationService,
Substitute.For<ILogger<CreatePremiumCloudHostedSubscriptionCommand>>(),
_pricingClient);
_pricingClient,
_hasPaymentMethodQuery,
_updatePaymentMethodCommand);
}
[Theory, BitAutoData]
@@ -314,7 +320,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
}
[Theory, BitAutoData]
public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer(
public async Task Run_UserHasExistingGatewayCustomerIdAndPaymentMethod_UsesExistingCustomer(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
@@ -347,6 +353,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockInvoice = Substitute.For<Invoice>();
// Mock that the user has a payment method (this is the key difference from the credit purchase case)
_hasPaymentMethodQuery.Run(Arg.Any<User>()).Returns(true);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
@@ -358,6 +366,75 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any<User>(), Arg.Any<TokenizedPaymentMethod>(), Arg.Any<BillingAddress>());
}
[Theory, BitAutoData]
public async Task Run_UserPreviouslyPurchasedCreditWithoutPaymentMethod_UpdatesPaymentMethodAndCreatesSubscription(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = "existing_customer_123"; // Customer exists from previous credit purchase
paymentMethod.Type = TokenizablePaymentMethodType.Card;
paymentMethod.Token = "card_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "existing_customer_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem
{
CurrentPeriodEnd = DateTime.UtcNow.AddDays(30)
}
]
};
var mockInvoice = Substitute.For<Invoice>();
MaskedPaymentMethod mockMaskedPaymentMethod = new MaskedCard
{
Brand = "visa",
Last4 = "1234",
Expiration = "12/2025"
};
// Mock that the user does NOT have a payment method (simulating credit purchase scenario)
_hasPaymentMethodQuery.Run(Arg.Any<User>()).Returns(false);
_updatePaymentMethodCommand.Run(Arg.Any<User>(), Arg.Any<TokenizedPaymentMethod>(), Arg.Any<BillingAddress>())
.Returns(mockMaskedPaymentMethod);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
// Verify that update payment method was called (new behavior for credit purchase case)
await _updatePaymentMethodCommand.Received(1).Run(user, paymentMethod, billingAddress);
// Verify GetCustomerOrThrow was called after updating payment method
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
// Verify no new customer was created
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
// Verify subscription was created
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
// Verify user was updated correctly
Assert.True(user.Premium);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]

View File

@@ -0,0 +1,527 @@
using System.Net;
using Bit.Core.Billing;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using RichardSzalay.MockHttp;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Billing.Pricing;
[SutProviderCustomize]
public class PricingClientTests
{
#region GetLookupKey Tests (via GetPlan)
[Fact]
public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagEnabled_UsesFamilies2025LookupKey()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var planJson = CreatePlanJson("families-2025", "Families 2025", "families", 40M, "price_id");
mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/families-2025")
.Respond("application/json", planJson);
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.NotNull(result);
Assert.Equal(PlanType.FamiliesAnnually2025, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
[Fact]
public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagDisabled_UsesFamiliesLookupKey()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id");
mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/families")
.Respond("application/json", planJson);
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.NotNull(result);
// PreProcessFamiliesPreMigrationPlan should change "families" to "families-2025" when FF is disabled
Assert.Equal(PlanType.FamiliesAnnually2025, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
#endregion
#region PreProcessFamiliesPreMigrationPlan Tests (via GetPlan)
[Fact]
public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagDisabled_ReturnsFamiliesAnnually2025PlanType()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
// billing-pricing returns "families" lookup key because the flag is off
var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id");
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.NotNull(result);
// PreProcessFamiliesPreMigrationPlan should convert the families lookup key to families-2025
// and the PlanAdapter should assign the correct FamiliesAnnually2025 plan type
Assert.Equal(PlanType.FamiliesAnnually2025, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
[Fact]
public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagEnabled_ReturnsFamiliesAnnually2025PlanType()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var planJson = CreatePlanJson("families-2025", "Families", "families", 40M, "price_id");
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.NotNull(result);
// PreProcessFamiliesPreMigrationPlan should ignore the lookup key because the flag is on
// and the PlanAdapter should assign the correct FamiliesAnnually2025 plan type
Assert.Equal(PlanType.FamiliesAnnually2025, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
[Fact]
public async Task GetPlan_WithFamiliesAnnuallyAndFeatureFlagEnabled_ReturnsFamiliesAnnuallyPlanType()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id");
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually);
// Assert
Assert.NotNull(result);
// PreProcessFamiliesPreMigrationPlan should ignore the lookup key because the flag is on
// and the PlanAdapter should assign the correct FamiliesAnnually plan type
Assert.Equal(PlanType.FamiliesAnnually, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
[Fact]
public async Task GetPlan_WithOtherLookupKey_KeepsLookupKeyUnchanged()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var planJson = CreatePlanJson("enterprise-annually", "Enterprise", "enterprise", 144M, "price_id");
mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/enterprise-annually")
.Respond("application/json", planJson);
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond("application/json", planJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.EnterpriseAnnually);
// Assert
Assert.NotNull(result);
Assert.Equal(PlanType.EnterpriseAnnually, result.Type);
mockHttp.VerifyNoOutstandingExpectation();
}
#endregion
#region ListPlans Tests
[Fact]
public async Task ListPlans_WithFeatureFlagDisabled_ReturnsListWithPreProcessing()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
// biling-pricing would return "families" because the flag is disabled
var plansJson = $@"[
{CreatePlanJson("families", "Families", "families", 40M, "price_id")},
{CreatePlanJson("enterprise-annually", "Enterprise", "enterprise", 144M, "price_id")}
]";
mockHttp.When(HttpMethod.Get, "*/plans/organization")
.Respond("application/json", plansJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.ListPlans();
// Assert
Assert.NotNull(result);
Assert.Equal(2, result.Count);
// First plan should have been preprocessed from "families" to "families-2025"
Assert.Equal(PlanType.FamiliesAnnually2025, result[0].Type);
// Second plan should remain unchanged
Assert.Equal(PlanType.EnterpriseAnnually, result[1].Type);
mockHttp.VerifyNoOutstandingExpectation();
}
[Fact]
public async Task ListPlans_WithFeatureFlagEnabled_ReturnsListWithoutPreProcessing()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
var plansJson = $@"[
{CreatePlanJson("families", "Families", "families", 40M, "price_id")}
]";
mockHttp.When(HttpMethod.Get, "*/plans/organization")
.Respond("application/json", plansJson);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.ListPlans();
// Assert
Assert.NotNull(result);
Assert.Single(result);
// Plan should remain as FamiliesAnnually when FF is enabled
Assert.Equal(PlanType.FamiliesAnnually, result[0].Type);
mockHttp.VerifyNoOutstandingExpectation();
}
#endregion
#region GetPlan - Additional Coverage
[Theory, BitAutoData]
public async Task GetPlan_WhenSelfHosted_ReturnsNull(
SutProvider<PricingClient> sutProvider)
{
// Arrange
var globalSettings = sutProvider.GetDependency<GlobalSettings>();
globalSettings.SelfHosted = true;
// Act
var result = await sutProvider.Sut.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.Null(result);
}
[Theory, BitAutoData]
public async Task GetPlan_WhenPricingServiceDisabled_ReturnsStaticStorePlan(
SutProvider<PricingClient> sutProvider)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.UsePricingService)
.Returns(false);
// Act
var result = await sutProvider.Sut.GetPlan(PlanType.FamiliesAnnually);
// Assert
Assert.NotNull(result);
Assert.Equal(PlanType.FamiliesAnnually, result.Type);
}
[Theory, BitAutoData]
public async Task GetPlan_WhenLookupKeyNotFound_ReturnsNull(
SutProvider<PricingClient> sutProvider)
{
// Arrange
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.UsePricingService)
.Returns(true);
// Act - Using PlanType that doesn't have a lookup key mapping
var result = await sutProvider.Sut.GetPlan(unchecked((PlanType)999));
// Assert
Assert.Null(result);
}
[Fact]
public async Task GetPlan_WhenPricingServiceReturnsNotFound_ReturnsNull()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond(HttpStatusCode.NotFound);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act
var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025);
// Assert
Assert.Null(result);
}
[Fact]
public async Task GetPlan_WhenPricingServiceReturnsError_ThrowsBillingException()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
mockHttp.When(HttpMethod.Get, "*/plans/organization/*")
.Respond(HttpStatusCode.InternalServerError);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true);
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act & Assert
await Assert.ThrowsAsync<BillingException>(() =>
pricingClient.GetPlan(PlanType.FamiliesAnnually2025));
}
#endregion
#region ListPlans - Additional Coverage
[Theory, BitAutoData]
public async Task ListPlans_WhenSelfHosted_ReturnsEmptyList(
SutProvider<PricingClient> sutProvider)
{
// Arrange
var globalSettings = sutProvider.GetDependency<GlobalSettings>();
globalSettings.SelfHosted = true;
// Act
var result = await sutProvider.Sut.ListPlans();
// Assert
Assert.NotNull(result);
Assert.Empty(result);
}
[Theory, BitAutoData]
public async Task ListPlans_WhenPricingServiceDisabled_ReturnsStaticStorePlans(
SutProvider<PricingClient> sutProvider)
{
// Arrange
sutProvider.GetDependency<GlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.UsePricingService)
.Returns(false);
// Act
var result = await sutProvider.Sut.ListPlans();
// Assert
Assert.NotNull(result);
Assert.NotEmpty(result);
Assert.Equal(StaticStore.Plans.Count(), result.Count);
}
[Fact]
public async Task ListPlans_WhenPricingServiceReturnsError_ThrowsBillingException()
{
// Arrange
var mockHttp = new MockHttpMessageHandler();
mockHttp.When(HttpMethod.Get, "*/plans/organization")
.Respond(HttpStatusCode.InternalServerError);
var featureService = Substitute.For<IFeatureService>();
featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true);
var globalSettings = new GlobalSettings { SelfHosted = false };
var httpClient = new HttpClient(mockHttp)
{
BaseAddress = new Uri("https://test.com/")
};
var logger = Substitute.For<ILogger<PricingClient>>();
var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger);
// Act & Assert
await Assert.ThrowsAsync<BillingException>(() =>
pricingClient.ListPlans());
}
#endregion
private static string CreatePlanJson(
string lookupKey,
string name,
string tier,
decimal seatsPrice,
string seatsStripePriceId,
int seatsQuantity = 1)
{
return $@"{{
""lookupKey"": ""{lookupKey}"",
""name"": ""{name}"",
""tier"": ""{tier}"",
""features"": [],
""seats"": {{
""type"": ""packaged"",
""quantity"": {seatsQuantity},
""price"": {seatsPrice},
""stripePriceId"": ""{seatsStripePriceId}""
}},
""canUpgradeTo"": [],
""additionalData"": {{
""nameLocalizationKey"": ""{lookupKey}Name"",
""descriptionLocalizationKey"": ""{lookupKey}Description""
}}
}}";
}
}

View File

@@ -38,31 +38,32 @@ public class OrganizationBillingServiceTests
var subscriberService = sutProvider.GetDependency<ISubscriberService>();
var organizationSeatCount = new OrganizationSeatCounts { Users = 1, Sponsored = 0 };
var customer = new Customer
{
Discount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = ["product_id"]
}
}
}
};
var customer = new Customer();
subscriberService
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options =>
options.Expand.Contains("discount.coupon.applies_to")))
.GetCustomer(organization)
.Returns(customer);
subscriberService.GetSubscription(organization).Returns(new Subscription
{
Items = new StripeList<SubscriptionItem>
subscriberService.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to"))).Returns(new Subscription
{
Data =
Discounts =
[
new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.SecretsManagerStandalone,
AppliesTo = new CouponAppliesTo
{
Products = ["product_id"]
}
}
}
],
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem
{
@@ -72,8 +73,8 @@ public class OrganizationBillingServiceTests
}
}
]
}
});
}
});
sutProvider.GetDependency<IOrganizationRepository>()
.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id)
@@ -109,11 +110,12 @@ public class OrganizationBillingServiceTests
// Set up subscriber service to return null for customer
subscriberService
.GetCustomer(organization, Arg.Is<CustomerGetOptions>(options => options.Expand.FirstOrDefault() == "discount.coupon.applies_to"))
.GetCustomer(organization)
.Returns((Customer)null);
// Set up subscriber service to return null for subscription
subscriberService.GetSubscription(organization).Returns((Subscription)null);
subscriberService.GetSubscription(organization, Arg.Is<SubscriptionGetOptions>(options =>
options.Expand.Contains("discounts.coupon.applies_to"))).Returns((Subscription)null);
var metadata = await sutProvider.Sut.GetMetadata(organizationId);

View File

@@ -28,6 +28,9 @@
<None Remove="Utilities\data\embeddedResource.txt" />
</ItemGroup>
<ItemGroup>
<!-- Email templates uses .hbs extension, they must be included for emails to work -->
<EmbeddedResource Include="**\*.hbs" />
<EmbeddedResource Include="Utilities\data\embeddedResource.txt" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,497 @@
using Bit.Core.Models.Business;
using Bit.Test.Common.AutoFixture.Attributes;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Models.Business;
public class BillingCustomerDiscountTests
{
[Theory]
[BitAutoData]
public void Constructor_PercentageDiscount_SetsIdActivePercentOffAndAppliesTo(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 25.5m,
AmountOff = null,
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "product1", "product2" }
}
},
End = null // Active discount
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(couponId, result.Id);
Assert.True(result.Active);
Assert.Equal(25.5m, result.PercentOff);
Assert.Null(result.AmountOff);
Assert.NotNull(result.AppliesTo);
Assert.Equal(2, result.AppliesTo.Count);
Assert.Contains("product1", result.AppliesTo);
Assert.Contains("product2", result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_AmountDiscount_ConvertsFromCentsToDollars(string couponId)
{
// Arrange - Stripe sends 1400 cents for $14.00
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = null,
AmountOff = 1400, // 1400 cents
AppliesTo = new CouponAppliesTo
{
Products = new List<string>()
}
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(couponId, result.Id);
Assert.True(result.Active);
Assert.Null(result.PercentOff);
Assert.Equal(14.00m, result.AmountOff); // Converted to dollars
Assert.NotNull(result.AppliesTo);
Assert.Empty(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_InactiveDiscount_SetsActiveToFalse(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 15m
},
End = DateTime.UtcNow.AddDays(-1) // Expired discount
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(couponId, result.Id);
Assert.False(result.Active);
Assert.Equal(15m, result.PercentOff);
}
[Fact]
public void Constructor_NullCoupon_SetsDiscountPropertiesToNull()
{
// Arrange
var discount = new Discount
{
Coupon = null,
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.Id);
Assert.True(result.Active);
Assert.Null(result.PercentOff);
Assert.Null(result.AmountOff);
Assert.Null(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_NullAmountOff_SetsAmountOffToNull(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 10m,
AmountOff = null
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_ZeroAmountOff_ConvertsCorrectly(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
AmountOff = 0
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(0m, result.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_LargeAmountOff_ConvertsCorrectly(string couponId)
{
// Arrange - $100.00 discount
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
AmountOff = 10000 // 10000 cents = $100.00
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(100.00m, result.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_SmallAmountOff_ConvertsCorrectly(string couponId)
{
// Arrange - $0.50 discount
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
AmountOff = 50 // 50 cents = $0.50
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(0.50m, result.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_BothDiscountTypes_SetsPercentOffAndAmountOff(string couponId)
{
// Arrange - Coupon with both percentage and amount (edge case)
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 20m,
AmountOff = 500 // $5.00
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(20m, result.PercentOff);
Assert.Equal(5.00m, result.AmountOff);
}
[Theory]
[BitAutoData]
public void Constructor_WithNullAppliesTo_SetsAppliesToNull(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 10m,
AppliesTo = null
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_WithNullProductsList_SetsAppliesToNull(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 10m,
AppliesTo = new CouponAppliesTo
{
Products = null
}
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_WithDecimalAmountOff_RoundsCorrectly(string couponId)
{
// Arrange - 1425 cents = $14.25
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
AmountOff = 1425
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Equal(14.25m, result.AmountOff);
}
[Fact]
public void Constructor_DefaultConstructor_InitializesAllPropertiesToNullOrFalse()
{
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount();
// Assert
Assert.Null(result.Id);
Assert.False(result.Active);
Assert.Null(result.PercentOff);
Assert.Null(result.AmountOff);
Assert.Null(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_WithFutureEndDate_SetsActiveToFalse(string couponId)
{
// Arrange - Discount expires in the future
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 20m
},
End = DateTime.UtcNow.AddDays(30) // Expires in 30 days
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.False(result.Active); // Should be inactive because End is not null
}
[Theory]
[BitAutoData]
public void Constructor_WithPastEndDate_SetsActiveToFalse(string couponId)
{
// Arrange - Discount already expired
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 20m
},
End = DateTime.UtcNow.AddDays(-30) // Expired 30 days ago
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.False(result.Active); // Should be inactive because End is not null
}
[Fact]
public void Constructor_WithNullCouponId_SetsIdToNull()
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = null,
PercentOff = 20m
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.Id);
Assert.True(result.Active);
Assert.Equal(20m, result.PercentOff);
}
[Theory]
[BitAutoData]
public void Constructor_WithNullPercentOff_SetsPercentOffToNull(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = null,
AmountOff = 1000
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.Null(result.PercentOff);
Assert.Equal(10.00m, result.AmountOff);
}
[Fact]
public void Constructor_WithCompleteStripeDiscount_MapsAllProperties()
{
// Arrange - Comprehensive test with all Stripe Discount properties set
var discount = new Discount
{
Coupon = new Coupon
{
Id = "premium_discount_2024",
PercentOff = 25m,
AmountOff = 1500, // $15.00
AppliesTo = new CouponAppliesTo
{
Products = new List<string> { "prod_premium", "prod_family", "prod_teams" }
}
},
End = null // Active
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert - Verify all properties mapped correctly
Assert.Equal("premium_discount_2024", result.Id);
Assert.True(result.Active);
Assert.Equal(25m, result.PercentOff);
Assert.Equal(15.00m, result.AmountOff);
Assert.NotNull(result.AppliesTo);
Assert.Equal(3, result.AppliesTo.Count);
Assert.Contains("prod_premium", result.AppliesTo);
Assert.Contains("prod_family", result.AppliesTo);
Assert.Contains("prod_teams", result.AppliesTo);
}
[Fact]
public void Constructor_WithMinimalStripeDiscount_HandlesNullsGracefully()
{
// Arrange - Minimal Stripe Discount with most properties null
var discount = new Discount
{
Coupon = new Coupon
{
Id = null,
PercentOff = null,
AmountOff = null,
AppliesTo = null
},
End = DateTime.UtcNow.AddDays(10) // Has end date
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert - Should handle all nulls gracefully
Assert.Null(result.Id);
Assert.False(result.Active);
Assert.Null(result.PercentOff);
Assert.Null(result.AmountOff);
Assert.Null(result.AppliesTo);
}
[Theory]
[BitAutoData]
public void Constructor_WithEmptyProductsList_PreservesEmptyList(string couponId)
{
// Arrange
var discount = new Discount
{
Coupon = new Coupon
{
Id = couponId,
PercentOff = 10m,
AppliesTo = new CouponAppliesTo
{
Products = new List<string>() // Empty but not null
}
},
End = null
};
// Act
var result = new SubscriptionInfo.BillingCustomerDiscount(discount);
// Assert
Assert.NotNull(result.AppliesTo);
Assert.Empty(result.AppliesTo);
}
}

View File

@@ -22,7 +22,7 @@ public class SecretsManagerSubscriptionUpdateTests
}
public static TheoryData<Plan> NonSmPlans =>
ToPlanTheory([PlanType.Custom, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019]);
ToPlanTheory([PlanType.Custom, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2025, PlanType.FamiliesAnnually2019]);
public static TheoryData<Plan> SmPlans => ToPlanTheory([
PlanType.EnterpriseAnnually2019,

View File

@@ -0,0 +1,125 @@
using Bit.Core.Models.Business;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Models.Business;
public class SubscriptionInfoTests
{
[Fact]
public void BillingSubscriptionItem_NullPlan_HandlesGracefully()
{
// Arrange - SubscriptionItem with null Plan
var subscriptionItem = new SubscriptionItem
{
Plan = null,
Quantity = 1
};
// Act
var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem);
// Assert - Should handle null Plan gracefully
Assert.Null(result.ProductId);
Assert.Null(result.Name);
Assert.Equal(0m, result.Amount); // Defaults to 0 when Plan is null
Assert.Null(result.Interval);
Assert.Equal(1, result.Quantity);
Assert.False(result.SponsoredSubscriptionItem);
Assert.False(result.AddonSubscriptionItem);
}
[Fact]
public void BillingSubscriptionItem_NullAmount_SetsToZero()
{
// Arrange - SubscriptionItem with Plan but null Amount
var subscriptionItem = new SubscriptionItem
{
Plan = new Plan
{
ProductId = "prod_test",
Nickname = "Test Plan",
Amount = null, // Null amount
Interval = "month"
},
Quantity = 1
};
// Act
var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem);
// Assert - Should default to 0 when Amount is null
Assert.Equal("prod_test", result.ProductId);
Assert.Equal("Test Plan", result.Name);
Assert.Equal(0m, result.Amount); // Business rule: defaults to 0 when null
Assert.Equal("month", result.Interval);
Assert.Equal(1, result.Quantity);
}
[Fact]
public void BillingSubscriptionItem_ZeroAmount_PreservesZero()
{
// Arrange - SubscriptionItem with Plan and zero Amount
var subscriptionItem = new SubscriptionItem
{
Plan = new Plan
{
ProductId = "prod_test",
Nickname = "Test Plan",
Amount = 0, // Zero amount (0 cents)
Interval = "month"
},
Quantity = 1
};
// Act
var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem);
// Assert - Should preserve zero amount
Assert.Equal("prod_test", result.ProductId);
Assert.Equal("Test Plan", result.Name);
Assert.Equal(0m, result.Amount); // Zero amount preserved
Assert.Equal("month", result.Interval);
}
[Fact]
public void BillingUpcomingInvoice_ZeroAmountDue_ConvertsToZero()
{
// Arrange - Invoice with zero AmountDue
// Note: Stripe's Invoice.AmountDue is non-nullable long, so we test with 0
// The null-coalescing operator (?? 0) in the constructor handles the case where
// ConvertFromStripeMinorUnits returns null, but since AmountDue is non-nullable,
// this test verifies the conversion path works correctly for zero values
var invoice = new Invoice
{
AmountDue = 0, // Zero amount due (0 cents)
Created = DateTime.UtcNow
};
// Act
var result = new SubscriptionInfo.BillingUpcomingInvoice(invoice);
// Assert - Should convert zero correctly
Assert.Equal(0m, result.Amount);
Assert.NotNull(result.Date);
}
[Fact]
public void BillingUpcomingInvoice_ValidAmountDue_ConvertsCorrectly()
{
// Arrange - Invoice with valid AmountDue
var invoice = new Invoice
{
AmountDue = 2500, // 2500 cents = $25.00
Created = DateTime.UtcNow
};
// Act
var result = new SubscriptionInfo.BillingUpcomingInvoice(invoice);
// Assert - Should convert correctly
Assert.Equal(25.00m, result.Amount); // Converted from cents
Assert.NotNull(result.Date);
}
}

View File

@@ -0,0 +1,172 @@
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Settings;
using Bit.Core.Test.Platform.Mailer.TestMail;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Platform.Mailer;
public class HandlebarMailRendererTests
{
[Fact]
public async Task RenderAsync_ReturnsExpectedHtmlAndTxt()
{
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var globalSettings = new GlobalSettings { SelfHosted = false };
var renderer = new HandlebarMailRenderer(logger, globalSettings);
var view = new TestMailView { Name = "John Smith" };
var (html, txt) = await renderer.RenderAsync(view);
Assert.Equal("Hello <b>John Smith</b>", html.Trim());
Assert.Equal("Hello John Smith", txt.Trim());
}
[Fact]
public async Task RenderAsync_LoadsFromDisk_WhenSelfHostedAndFileExists()
{
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
Directory.CreateDirectory(tempDir);
try
{
var globalSettings = new GlobalSettings
{
SelfHosted = true,
MailTemplateDirectory = tempDir
};
// Create test template files on disk
var htmlTemplatePath = Path.Combine(tempDir, "Bit.Core.Test.Platform.Mailer.TestMail.TestMailView.html.hbs");
var txtTemplatePath = Path.Combine(tempDir, "Bit.Core.Test.Platform.Mailer.TestMail.TestMailView.text.hbs");
await File.WriteAllTextAsync(htmlTemplatePath, "Custom HTML: <b>{{Name}}</b>");
await File.WriteAllTextAsync(txtTemplatePath, "Custom TXT: {{Name}}");
var renderer = new HandlebarMailRenderer(logger, globalSettings);
var view = new TestMailView { Name = "Jane Doe" };
var (html, txt) = await renderer.RenderAsync(view);
Assert.Equal("Custom HTML: <b>Jane Doe</b>", html.Trim());
Assert.Equal("Custom TXT: Jane Doe", txt.Trim());
}
finally
{
// Cleanup
if (Directory.Exists(tempDir))
{
Directory.Delete(tempDir, true);
}
}
}
[Theory]
[InlineData("../../../etc/passwd")]
[InlineData("../../../../malicious.txt")]
[InlineData("../../malicious.txt")]
[InlineData("../malicious.txt")]
public async Task ReadSourceFromDiskAsync_PrevenetsPathTraversal_WhenMaliciousPathProvided(string maliciousPath)
{
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
Directory.CreateDirectory(tempDir);
try
{
var globalSettings = new GlobalSettings
{
SelfHosted = true,
MailTemplateDirectory = tempDir
};
// Create a malicious file outside the template directory
var maliciousFile = Path.Combine(Path.GetTempPath(), "malicious.txt");
await File.WriteAllTextAsync(maliciousFile, "Malicious Content");
var renderer = new HandlebarMailRenderer(logger, globalSettings);
// Use reflection to call the private ReadSourceFromDiskAsync method
var method = typeof(HandlebarMailRenderer).GetMethod("ReadSourceFromDiskAsync",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var task = (Task<string?>)method!.Invoke(renderer, new object[] { maliciousPath })!;
var result = await task;
// Should return null and not load the malicious file
Assert.Null(result);
// Verify that a warning was logged for the path traversal attempt
logger.Received(1).Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
// Cleanup malicious file
if (File.Exists(maliciousFile))
{
File.Delete(maliciousFile);
}
}
finally
{
// Cleanup
if (Directory.Exists(tempDir))
{
Directory.Delete(tempDir, true);
}
}
}
[Fact]
public async Task ReadSourceFromDiskAsync_AllowsValidFileWithDifferentCase_WhenCaseInsensitiveFileSystem()
{
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
Directory.CreateDirectory(tempDir);
try
{
var globalSettings = new GlobalSettings
{
SelfHosted = true,
MailTemplateDirectory = tempDir
};
// Create a test template file
var templateFileName = "TestTemplate.hbs";
var templatePath = Path.Combine(tempDir, templateFileName);
await File.WriteAllTextAsync(templatePath, "Test Content");
var renderer = new HandlebarMailRenderer(logger, globalSettings);
// Try to read with different case (should work on case-insensitive file systems like Windows/macOS)
var method = typeof(HandlebarMailRenderer).GetMethod("ReadSourceFromDiskAsync",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var task = (Task<string?>)method!.Invoke(renderer, new object[] { templateFileName })!;
var result = await task;
// Should successfully read the file
Assert.Equal("Test Content", result);
// Verify no warning was logged
logger.DidNotReceive().Log(
LogLevel.Warning,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception, string>>());
}
finally
{
// Cleanup
if (Directory.Exists(tempDir))
{
Directory.Delete(tempDir, true);
}
}
}
}

View File

@@ -0,0 +1,42 @@
using Bit.Core.Models.Mail;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Mailer;
using Bit.Core.Settings;
using Bit.Core.Test.Platform.Mailer.TestMail;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Platform.Mailer;
public class MailerTest
{
[Fact]
public async Task SendEmailAsync()
{
var logger = Substitute.For<ILogger<HandlebarMailRenderer>>();
var globalSettings = new GlobalSettings { SelfHosted = false };
var deliveryService = Substitute.For<IMailDeliveryService>();
var mailer = new Core.Platform.Mail.Mailer.Mailer(new HandlebarMailRenderer(logger, globalSettings), deliveryService);
var mail = new TestMail.TestMail()
{
ToEmails = ["test@bw.com"],
View = new TestMailView() { Name = "John Smith" }
};
MailMessage? sentMessage = null;
await deliveryService.SendEmailAsync(Arg.Do<MailMessage>(message =>
sentMessage = message
));
await mailer.SendEmail(mail);
Assert.NotNull(sentMessage);
Assert.Contains("test@bw.com", sentMessage.ToEmails);
Assert.Equal("Test Email", sentMessage.Subject);
Assert.Equivalent("Hello John Smith", sentMessage.TextContent.Trim());
Assert.Equivalent("Hello <b>John Smith</b>", sentMessage.HtmlContent.Trim());
}
}

View File

@@ -0,0 +1,13 @@
using Bit.Core.Platform.Mail.Mailer;
namespace Bit.Core.Test.Platform.Mailer.TestMail;
public class TestMailView : BaseMailView
{
public required string Name { get; init; }
}
public class TestMail : BaseMail<TestMailView>
{
public override string Subject { get; } = "Test Email";
}

View File

@@ -0,0 +1 @@
Hello <b>{{ Name }}</b>

View File

@@ -0,0 +1 @@
Hello {{ Name }}

View File

@@ -1,7 +1,7 @@
using Amazon.SimpleEmail;
using Amazon.SimpleEmail.Model;
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;

View File

@@ -6,7 +6,10 @@ using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Business;
using Bit.Core.Entities;
using Bit.Core.Models.Mail;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Mail.Enqueuing;
using Bit.Core.Services;
using Bit.Core.Services.Mail;
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
@@ -265,4 +268,115 @@ public class HandlebarsMailServiceTests
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any<MailMessage>());
}
[Fact]
public async Task SendIndividualUserWelcomeEmailAsync_SendsCorrectEmail()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Email = "test@example.com"
};
// Act
await _sut.SendIndividualUserWelcomeEmailAsync(user);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is<MailMessage>(m =>
m.MetaData != null &&
m.ToEmails.Contains("test@example.com") &&
m.Subject == "Welcome to Bitwarden!" &&
m.Category == "Welcome"));
}
[Fact]
public async Task SendOrganizationUserWelcomeEmailAsync_SendsCorrectEmailWithOrganizationName()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Email = "user@company.com"
};
var organizationName = "Bitwarden Corp";
// Act
await _sut.SendOrganizationUserWelcomeEmailAsync(user, organizationName);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is<MailMessage>(m =>
m.MetaData != null &&
m.ToEmails.Contains("user@company.com") &&
m.Subject == "Welcome to Bitwarden!" &&
m.HtmlContent.Contains("Bitwarden Corp") &&
m.Category == "Welcome"));
}
[Fact]
public async Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync_SendsCorrectEmailWithFamilyTemplate()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Email = "family@example.com"
};
var familyOrganizationName = "Smith Family";
// Act
await _sut.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, familyOrganizationName);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is<MailMessage>(m =>
m.MetaData != null &&
m.ToEmails.Contains("family@example.com") &&
m.Subject == "Welcome to Bitwarden!" &&
m.HtmlContent.Contains("Smith Family") &&
m.Category == "Welcome"));
}
[Theory]
[InlineData("Acme Corp", "Acme Corp")]
[InlineData("Company & Associates", "Company &amp; Associates")]
[InlineData("Test \"Quoted\" Org", "Test &quot;Quoted&quot; Org")]
public async Task SendOrganizationUserWelcomeEmailAsync_SanitizesOrganizationNameForEmail(string inputOrgName, string expectedSanitized)
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Email = "test@example.com"
};
// Act
await _sut.SendOrganizationUserWelcomeEmailAsync(user, inputOrgName);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is<MailMessage>(m =>
m.HtmlContent.Contains(expectedSanitized) &&
!m.HtmlContent.Contains("<script>") && // Ensure script tags are removed
m.Category == "Welcome"));
}
[Theory]
[InlineData("test@example.com")]
[InlineData("user+tag@domain.co.uk")]
[InlineData("admin@organization.org")]
public async Task SendIndividualUserWelcomeEmailAsync_HandlesVariousEmailFormats(string email)
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Email = email
};
// Act
await _sut.SendIndividualUserWelcomeEmailAsync(user);
// Assert
await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is<MailMessage>(m =>
m.ToEmails.Contains(email)));
}
}

View File

@@ -1,4 +1,4 @@
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
using NSubstitute;

View File

@@ -1,5 +1,5 @@
using Bit.Core.Models.Mail;
using Bit.Core.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Settings;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;

View File

@@ -3,6 +3,7 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models.StaticStore.Plans;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
@@ -515,4 +516,399 @@ public class StripePaymentServiceTests
options.CustomerDetails.TaxExempt == StripeConstants.TaxExempt.Reverse
));
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithCustomerDiscount_ReturnsDiscountFromCustomer(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var customerDiscount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
PercentOff = 20m,
AmountOff = 1400
},
End = null
};
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = customerDiscount
},
Discounts = new List<Discount>(), // Empty list
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Equal(20m, result.CustomerDiscount.PercentOff);
Assert.Equal(14.00m, result.CustomerDiscount.AmountOff); // Converted from cents
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithoutCustomerDiscount_FallsBackToSubscriptionDiscounts(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var subscriptionDiscount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
PercentOff = 15m,
AmountOff = null
},
End = null
};
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = null // No customer discount
},
Discounts = new List<Discount> { subscriptionDiscount },
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Should use subscription discount as fallback
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Equal(15m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithBothDiscounts_PrefersCustomerDiscount(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var customerDiscount = new Discount
{
Coupon = new Coupon
{
Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount,
PercentOff = 25m
},
End = null
};
var subscriptionDiscount = new Discount
{
Coupon = new Coupon
{
Id = "different-coupon-id",
PercentOff = 10m
},
End = null
};
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = customerDiscount // Should prefer this
},
Discounts = new List<Discount> { subscriptionDiscount },
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Should prefer customer discount over subscription discount
Assert.NotNull(result.CustomerDiscount);
Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id);
Assert.Equal(25m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithNoDiscounts_ReturnsNullDiscount(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = null
},
Discounts = new List<Discount>(), // Empty list, no discounts
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithMultipleSubscriptionDiscounts_SelectsFirstDiscount(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange - Multiple subscription-level discounts, no customer discount
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var firstDiscount = new Discount
{
Coupon = new Coupon
{
Id = "coupon-10-percent",
PercentOff = 10m
},
End = null
};
var secondDiscount = new Discount
{
Coupon = new Coupon
{
Id = "coupon-20-percent",
PercentOff = 20m
},
End = null
};
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = null // No customer discount
},
// Multiple subscription discounts - FirstOrDefault() should select the first one
Discounts = new List<Discount> { firstDiscount, secondDiscount },
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Should select the first discount from the list (FirstOrDefault() behavior)
Assert.NotNull(result.CustomerDiscount);
Assert.Equal("coupon-10-percent", result.CustomerDiscount.Id);
Assert.Equal(10m, result.CustomerDiscount.PercentOff);
// Verify the second discount was not selected
Assert.NotEqual("coupon-20-percent", result.CustomerDiscount.Id);
Assert.NotEqual(20m, result.CustomerDiscount.PercentOff);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithNullCustomer_HandlesGracefully(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange - Subscription with null Customer (defensive null check scenario)
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = null, // Customer not expanded or null
Discounts = new List<Discount>(), // Empty discounts
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Should handle null Customer gracefully without throwing NullReferenceException
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithNullDiscounts_HandlesGracefully(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange - Subscription with null Discounts (defensive null check scenario)
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer
{
Discount = null // No customer discount
},
Discounts = null, // Discounts not expanded or null
Items = new StripeList<SubscriptionItem> { Data = [] }
};
sutProvider.GetDependency<IStripeAdapter>()
.SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Should handle null Discounts gracefully without throwing NullReferenceException
Assert.Null(result.CustomerDiscount);
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_VerifiesCorrectExpandOptions(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = "cus_test123";
subscriber.GatewaySubscriptionId = "sub_test123";
var subscription = new Subscription
{
Id = "sub_test123",
Status = "active",
CollectionMethod = "charge_automatically",
Customer = new Customer { Discount = null },
Discounts = new List<Discount>(), // Empty list
Items = new StripeList<SubscriptionItem> { Data = [] }
};
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter
.SubscriptionGetAsync(
Arg.Any<string>(),
Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
// Act
await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert - Verify expand options are correct
await stripeAdapter.Received(1).SubscriptionGetAsync(
subscriber.GatewaySubscriptionId,
Arg.Is<SubscriptionGetOptions>(o =>
o.Expand.Contains("customer.discount.coupon.applies_to") &&
o.Expand.Contains("discounts.coupon.applies_to") &&
o.Expand.Contains("test_clock")));
}
[Theory]
[BitAutoData]
public async Task GetSubscriptionAsync_WithEmptyGatewaySubscriptionId_ReturnsEmptySubscriptionInfo(
SutProvider<StripePaymentService> sutProvider,
User subscriber)
{
// Arrange
subscriber.GatewaySubscriptionId = null;
// Act
var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber);
// Assert
Assert.NotNull(result);
Assert.Null(result.Subscription);
Assert.Null(result.CustomerDiscount);
Assert.Null(result.UpcomingInvoice);
// Verify no Stripe API calls were made
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceive()
.SubscriptionGetAsync(Arg.Any<string>(), Arg.Any<SubscriptionGetOptions>());
}
}

View File

@@ -0,0 +1,171 @@
using Bit.Core.Settings;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using NSubstitute;
using StackExchange.Redis;
using Xunit;
using ZiggyCreatures.Caching.Fusion;
namespace Bit.Core.Test.Utilities;
public class ExtendedCacheServiceCollectionExtensionsTests
{
private readonly IServiceCollection _services;
private readonly GlobalSettings _globalSettings;
public ExtendedCacheServiceCollectionExtensionsTests()
{
_services = new ServiceCollection();
var config = new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string?>())
.Build();
_globalSettings = new GlobalSettings();
config.GetSection("GlobalSettings").Bind(_globalSettings);
_services.TryAddSingleton(config);
_services.TryAddSingleton(_globalSettings);
_services.TryAddSingleton<IGlobalSettings>(_globalSettings);
_services.AddLogging();
}
[Fact]
public void TryAddFusionCoreServices_CustomSettings_OverridesDefaults()
{
var settings = CreateGlobalSettings(new Dictionary<string, string?>
{
{ "GlobalSettings:DistributedCache:Duration", "00:12:00" },
{ "GlobalSettings:DistributedCache:FailSafeMaxDuration", "01:30:00" },
{ "GlobalSettings:DistributedCache:FailSafeThrottleDuration", "00:01:00" },
{ "GlobalSettings:DistributedCache:EagerRefreshThreshold", "0.75" },
{ "GlobalSettings:DistributedCache:FactorySoftTimeout", "00:00:00.020" },
{ "GlobalSettings:DistributedCache:FactoryHardTimeout", "00:00:03" },
{ "GlobalSettings:DistributedCache:DistributedCacheSoftTimeout", "00:00:00.500" },
{ "GlobalSettings:DistributedCache:DistributedCacheHardTimeout", "00:00:01.500" },
{ "GlobalSettings:DistributedCache:JitterMaxDuration", "00:00:05" },
{ "GlobalSettings:DistributedCache:IsFailSafeEnabled", "false" },
{ "GlobalSettings:DistributedCache:AllowBackgroundDistributedCacheOperations", "false" },
});
_services.TryAddExtendedCacheServices(settings);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
var options = fusionCache.DefaultEntryOptions;
Assert.Equal(TimeSpan.FromMinutes(12), options.Duration);
Assert.False(options.IsFailSafeEnabled);
Assert.Equal(TimeSpan.FromHours(1.5), options.FailSafeMaxDuration);
Assert.Equal(TimeSpan.FromMinutes(1), options.FailSafeThrottleDuration);
Assert.Equal(0.75f, options.EagerRefreshThreshold);
Assert.Equal(TimeSpan.FromMilliseconds(20), options.FactorySoftTimeout);
Assert.Equal(TimeSpan.FromMilliseconds(3000), options.FactoryHardTimeout);
Assert.Equal(TimeSpan.FromSeconds(0.5), options.DistributedCacheSoftTimeout);
Assert.Equal(TimeSpan.FromSeconds(1.5), options.DistributedCacheHardTimeout);
Assert.False(options.AllowBackgroundDistributedCacheOperations);
Assert.Equal(TimeSpan.FromSeconds(5), options.JitterMaxDuration);
}
[Fact]
public void TryAddFusionCoreServices_DefaultSettings_ConfiguresExpectedValues()
{
_services.TryAddExtendedCacheServices(_globalSettings);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
var options = fusionCache.DefaultEntryOptions;
Assert.Equal(TimeSpan.FromMinutes(30), options.Duration);
Assert.True(options.IsFailSafeEnabled);
Assert.Equal(TimeSpan.FromHours(2), options.FailSafeMaxDuration);
Assert.Equal(TimeSpan.FromSeconds(30), options.FailSafeThrottleDuration);
Assert.Equal(0.9f, options.EagerRefreshThreshold);
Assert.Equal(TimeSpan.FromMilliseconds(100), options.FactorySoftTimeout);
Assert.Equal(TimeSpan.FromMilliseconds(1500), options.FactoryHardTimeout);
Assert.Equal(TimeSpan.FromSeconds(1), options.DistributedCacheSoftTimeout);
Assert.Equal(TimeSpan.FromSeconds(2), options.DistributedCacheHardTimeout);
Assert.True(options.AllowBackgroundDistributedCacheOperations);
Assert.Equal(TimeSpan.FromSeconds(2), options.JitterMaxDuration);
}
[Fact]
public void TryAddFusionCoreServices_MultipleCalls_OnlyConfiguresOnce()
{
var settings = CreateGlobalSettings(new Dictionary<string, string?>
{
{ "GlobalSettings:DistributedCache:Redis:ConnectionString", "localhost:6379" },
});
_services.AddSingleton(Substitute.For<IConnectionMultiplexer>());
_services.TryAddExtendedCacheServices(settings);
_services.TryAddExtendedCacheServices(settings);
_services.TryAddExtendedCacheServices(settings);
var registrations = _services.Where(s => s.ServiceType == typeof(IFusionCache)).ToList();
Assert.Single(registrations);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
Assert.NotNull(fusionCache);
}
[Fact]
public void TryAddFusionCoreServices_WithRedis_EnablesDistributedCacheAndBackplane()
{
var settings = CreateGlobalSettings(new Dictionary<string, string?>
{
{ "GlobalSettings:DistributedCache:Redis:ConnectionString", "localhost:6379" },
});
_services.AddSingleton(Substitute.For<IConnectionMultiplexer>());
_services.TryAddExtendedCacheServices(settings);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
Assert.True(fusionCache.HasDistributedCache);
Assert.True(fusionCache.HasBackplane);
}
[Fact]
public void TryAddFusionCoreServices_WithExistingRedis_EnablesDistributedCacheAndBackplane()
{
var settings = CreateGlobalSettings(new Dictionary<string, string?>
{
{ "GlobalSettings:DistributedCache:Redis:ConnectionString", "localhost:6379" },
});
_services.AddSingleton(Substitute.For<IConnectionMultiplexer>());
_services.AddSingleton(Substitute.For<IDistributedCache>());
_services.TryAddExtendedCacheServices(settings);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
Assert.True(fusionCache.HasDistributedCache);
Assert.True(fusionCache.HasBackplane);
var distributedCache = provider.GetRequiredService<IDistributedCache>();
Assert.NotNull(distributedCache);
}
[Fact]
public void TryAddFusionCoreServices_WithoutRedis_DisablesDistributedCacheAndBackplane()
{
_services.TryAddExtendedCacheServices(_globalSettings);
using var provider = _services.BuildServiceProvider();
var fusionCache = provider.GetRequiredService<IFusionCache>();
Assert.False(fusionCache.HasDistributedCache);
Assert.False(fusionCache.HasBackplane);
}
private static GlobalSettings CreateGlobalSettings(Dictionary<string, string?> data)
{
var config = new ConfigurationBuilder()
.AddInMemoryCollection(data)
.Build();
var settings = new GlobalSettings();
config.GetSection("GlobalSettings").Bind(settings);
return settings;
}
}

View File

@@ -13,7 +13,7 @@ public class StaticStoreTests
var plans = StaticStore.Plans.ToList();
Assert.NotNull(plans);
Assert.NotEmpty(plans);
Assert.Equal(22, plans.Count);
Assert.Equal(23, plans.Count);
}
[Theory]
@@ -34,8 +34,8 @@ public class StaticStoreTests
{
// Ref: https://daniel.haxx.se/blog/2025/05/16/detecting-malicious-unicode/
// URLs can contain unicode characters that to a computer would point to completely seperate domains but to the
// naked eye look completely identical. For example 'g' and 'ց' look incredibly similar but when included in a
// URL would lead you somewhere different. There is an opening for an attacker to contribute to Bitwarden with a
// naked eye look completely identical. For example 'g' and 'ց' look incredibly similar but when included in a
// URL would lead you somewhere different. There is an opening for an attacker to contribute to Bitwarden with a
// url update that could be missed in code review and then if they got a user to that URL Bitwarden could
// consider it equivalent with a cipher in the users vault and offer autofill when we should not.
// GitHub does now show a warning on non-ascii characters but it could still be missed.

View File

@@ -2286,6 +2286,63 @@ public class CipherServiceTests
.PushSyncCiphersAsync(deletingUserId);
}
[Theory]
[BitAutoData]
public async Task SoftDeleteAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, CipherDetails cipherDetails, SutProvider<CipherService> sutProvider)
{
cipherDetails.UserId = deletingUserId;
cipherDetails.OrganizationId = null;
cipherDetails.DeletedDate = null;
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});
await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId);
await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 1 && ids.First() == cipherDetails.Id));
}
[Theory]
[BitAutoData]
public async Task SoftDeleteManyAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, List<CipherDetails> ciphers, SutProvider<CipherService> sutProvider)
{
var cipherIds = ciphers.Select(c => c.Id).ToArray();
foreach (var cipher in ciphers)
{
cipher.UserId = deletingUserId;
cipher.OrganizationId = null;
cipher.Edit = true;
cipher.DeletedDate = null;
}
sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});
sutProvider.GetDependency<ICipherRepository>()
.GetManyByUserIdAsync(deletingUserId)
.Returns(ciphers);
await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false);
await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
}
private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
{
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);

View File

@@ -1,21 +1,63 @@
using System.Net.Http.Json;
using System.Net;
using System.Net.Http.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums;
using Bit.Core.Vault.Repositories;
using Bit.Events.Models;
namespace Bit.Events.IntegrationTest.Controllers;
public class CollectControllerTests
public class CollectControllerTests : IAsyncLifetime
{
// This is a very simple test, and should be updated to assert more things, but for now
// it ensures that the events startup doesn't throw any errors with fairly basic configuration.
[Fact]
public async Task Post_Works()
{
var eventsApplicationFactory = new EventsApplicationFactory();
var (accessToken, _) = await eventsApplicationFactory.LoginWithNewAccount();
var client = eventsApplicationFactory.CreateAuthedClient(accessToken);
private EventsApplicationFactory _factory = null!;
private HttpClient _client = null!;
private string _ownerEmail = null!;
private Guid _ownerId;
var response = await client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
public async Task InitializeAsync()
{
_factory = new EventsApplicationFactory();
_ownerEmail = $"integration-test+{Guid.NewGuid()}@bitwarden.com";
var (accessToken, _) = await _factory.LoginWithNewAccount(_ownerEmail);
_client = _factory.CreateAuthedClient(accessToken);
// Get the user ID
var userRepository = _factory.GetService<IUserRepository>();
var user = await userRepository.GetByEmailAsync(_ownerEmail);
_ownerId = user!.Id;
}
public Task DisposeAsync()
{
_client?.Dispose();
_factory?.Dispose();
return Task.CompletedTask;
}
[Fact]
public async Task Post_NullModel_ReturnsBadRequest()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>?>("collect", null);
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Post_EmptyModel_ReturnsBadRequest()
{
var response = await _client.PostAsJsonAsync("collect", Array.Empty<EventModel>());
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task Post_UserClientExportedVault_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
@@ -26,4 +68,425 @@ public class CollectControllerTests
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientAutofilled_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedPassword_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedHiddenField_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedHiddenField,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientCopiedCardCode_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientCopiedCardCode,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledCardNumberVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledCardNumberVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledCardCodeVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledCardCodeVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledHiddenFieldVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledHiddenFieldVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientToggledPasswordVisible_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientToggledPasswordVisible,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherClientViewed_WithValidCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEvent_WithoutCipherId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEvent_WithInvalidCipherId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = Guid.NewGuid(),
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithValidOrganization_Success()
{
var organization = await CreateOrganizationAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = organization.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithoutOrganizationId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_OrganizationClientExportedVault_WithInvalidOrganizationId_Success()
{
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = Guid.NewGuid(),
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_MultipleEvents_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var organization = await CreateOrganizationAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = organization.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherEventsBatch_MoreThan50Items_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
// Create 60 cipher events to test batching logic (should be processed in 2 batches of 50)
var events = Enumerable.Range(0, 60)
.Select(_ => new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
})
.ToList();
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect", events);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_UnsupportedEventType_Success()
{
// Testing with an event type not explicitly handled in the switch statement
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_LoggedIn,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_MixedValidAndInvalidEvents_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = Guid.NewGuid(), // Invalid cipher ID
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id, // Valid cipher ID
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
[Fact]
public async Task Post_CipherCaching_MultipleEventsForSameCipher_Success()
{
var cipher = await CreateCipherForUserAsync(_ownerId);
// Multiple events for the same cipher should use caching
var response = await _client.PostAsJsonAsync<IEnumerable<EventModel>>("collect",
[
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipher.Id,
Date = DateTime.UtcNow,
},
]);
response.EnsureSuccessStatusCode();
}
private async Task<Cipher> CreateCipherForUserAsync(Guid userId)
{
var cipherRepository = _factory.GetService<ICipherRepository>();
var cipher = new Cipher
{
Type = CipherType.Login,
UserId = userId,
Data = "{\"name\":\"Test Cipher\"}",
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await cipherRepository.CreateAsync(cipher);
return cipher;
}
private async Task<Organization> CreateOrganizationAsync(Guid ownerId)
{
var organizationRepository = _factory.GetService<IOrganizationRepository>();
var organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
var organization = new Organization
{
Name = "Test Organization",
BillingEmail = _ownerEmail,
Plan = "Free",
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await organizationRepository.CreateAsync(organization);
// Add the user as an owner of the organization
var organizationUser = new OrganizationUser
{
OrganizationId = organization.Id,
UserId = ownerId,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
};
await organizationUserRepository.CreateAsync(organizationUser);
return organization;
}
}

View File

@@ -0,0 +1,715 @@
using AutoFixture.Xunit2;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories;
using Bit.Events.Controllers;
using Bit.Events.Models;
using Microsoft.AspNetCore.Mvc;
using NSubstitute;
namespace Events.Test.Controllers;
public class CollectControllerTests
{
private readonly CollectController _sut;
private readonly ICurrentContext _currentContext;
private readonly IEventService _eventService;
private readonly ICipherRepository _cipherRepository;
private readonly IOrganizationRepository _organizationRepository;
public CollectControllerTests()
{
_currentContext = Substitute.For<ICurrentContext>();
_eventService = Substitute.For<IEventService>();
_cipherRepository = Substitute.For<ICipherRepository>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_sut = new CollectController(
_currentContext,
_eventService,
_cipherRepository,
_organizationRepository
);
}
[Fact]
public async Task Post_NullModel_ReturnsBadRequest()
{
var result = await _sut.Post(null);
Assert.IsType<BadRequestResult>(result);
}
[Fact]
public async Task Post_EmptyModel_ReturnsBadRequest()
{
var result = await _sut.Post(new List<EventModel>());
Assert.IsType<BadRequestResult>(result);
}
[Theory]
[AutoData]
public async Task Post_UserClientExportedVault_LogsUserEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogUserEventAsync(userId, EventType.User_ClientExportedVault, eventDate);
}
[Theory]
[AutoData]
public async Task Post_CipherAutofilled_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.Count() == 1 &&
tuples.First().Item1 == cipherDetails &&
tuples.First().Item2 == EventType.Cipher_ClientAutofilled &&
tuples.First().Item3 == eventDate
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedPassword_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedPassword,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedPassword
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedHiddenField_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedHiddenField,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedHiddenField
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientCopiedCardCode_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientCopiedCardCode,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientCopiedCardCode
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledCardNumberVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledCardNumberVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledCardNumberVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledCardCodeVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledCardCodeVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledCardCodeVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledHiddenFieldVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledHiddenFieldVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledHiddenFieldVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientToggledPasswordVisible_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientToggledPasswordVisible,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientToggledPasswordVisible
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherClientViewed_WithValidCipher_LogsCipherEvent(Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipherId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item2 == EventType.Cipher_ClientViewed
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithoutCipherId_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default, default);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_WithoutOrgId_SkipsEvent(Guid userId, Guid cipherId)
{
_currentContext.UserId.Returns(userId);
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(cipherId);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_WithOrgId_ChecksOrgCipher(
Guid userId, Guid cipherId, Guid orgId, Cipher cipher, CurrentContextOrganization org)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
_currentContext.GetOrganization(orgId).Returns(org);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.Received(1).GetByIdAsync(cipherId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(
tuples => tuples.First().Item1 == cipher
)
);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_WithNullCipher_OrgCipherNotFound_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId)
{
_currentContext.UserId.Returns(userId);
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns((CipherDetails?)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _cipherRepository.Received(1).GetByIdAsync(cipherId);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_CipherDoesNotBelongToOrg_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId, Guid differentOrgId, Cipher cipher)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = differentOrgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_CipherEvent_OrgNotFound_SkipsEvent(
Guid userId, Guid cipherId, Guid orgId, Cipher cipher)
{
_currentContext.UserId.Returns(userId);
cipher.Id = cipherId;
cipher.OrganizationId = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns((CipherDetails?)null);
_cipherRepository.GetByIdAsync(cipherId).Returns(cipher);
_currentContext.GetOrganization(orgId).Returns((CurrentContextOrganization)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogCipherEventsAsync(default);
}
[Theory]
[AutoData]
public async Task Post_MultipleCipherEvents_WithSameCipherId_UsesCachedCipher(
Guid userId, Guid cipherId, CipherDetails cipherDetails)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Cipher_ClientViewed,
CipherId = cipherId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _cipherRepository.Received(1).GetByIdAsync(cipherId, userId);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 2)
);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithValidOrg_LogsOrgEvent(
Guid userId, Guid orgId, Organization organization)
{
_currentContext.UserId.Returns(userId);
organization.Id = orgId;
_organizationRepository.GetByIdAsync(orgId).Returns(organization);
var eventDate = DateTime.UtcNow;
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = eventDate
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.Received(1).GetByIdAsync(orgId);
await _eventService.Received(1).LogOrganizationEventAsync(organization, EventType.Organization_ClientExportedVault, eventDate);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithoutOrgId_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = null,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
await _eventService.DidNotReceiveWithAnyArgs().LogOrganizationEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_OrganizationClientExportedVault_WithNullOrg_SkipsEvent(Guid userId, Guid orgId)
{
_currentContext.UserId.Returns(userId);
_organizationRepository.GetByIdAsync(orgId).Returns((Organization)null);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _organizationRepository.Received(1).GetByIdAsync(orgId);
await _eventService.DidNotReceiveWithAnyArgs().LogOrganizationEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_UnsupportedEventType_SkipsEvent(Guid userId)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_LoggedIn,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.DidNotReceiveWithAnyArgs().LogUserEventAsync(default, default, default);
}
[Theory]
[AutoData]
public async Task Post_MixedEventTypes_ProcessesAllEvents(
Guid userId, Guid cipherId, Guid orgId, CipherDetails cipherDetails, Organization organization)
{
_currentContext.UserId.Returns(userId);
cipherDetails.Id = cipherId;
organization.Id = orgId;
_cipherRepository.GetByIdAsync(cipherId, userId).Returns(cipherDetails);
_organizationRepository.GetByIdAsync(orgId).Returns(organization);
var events = new List<EventModel>
{
new EventModel
{
Type = EventType.User_ClientExportedVault,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipherId,
Date = DateTime.UtcNow
},
new EventModel
{
Type = EventType.Organization_ClientExportedVault,
OrganizationId = orgId,
Date = DateTime.UtcNow
}
};
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogUserEventAsync(userId, EventType.User_ClientExportedVault, Arg.Any<DateTime?>());
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 1)
);
await _eventService.Received(1).LogOrganizationEventAsync(organization, EventType.Organization_ClientExportedVault, Arg.Any<DateTime?>());
}
[Theory]
[AutoData]
public async Task Post_MoreThan50CipherEvents_LogsInBatches(Guid userId, List<CipherDetails> ciphers)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>();
for (int i = 0; i < 100; i++)
{
var cipher = ciphers[i % ciphers.Count];
_cipherRepository.GetByIdAsync(cipher.Id, userId).Returns(cipher);
events.Add(new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow
});
}
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(2).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 50)
);
}
[Theory]
[AutoData]
public async Task Post_Exactly50CipherEvents_LogsInSingleBatch(Guid userId, List<CipherDetails> ciphers)
{
_currentContext.UserId.Returns(userId);
var events = new List<EventModel>();
for (int i = 0; i < 50; i++)
{
var cipher = ciphers[i % ciphers.Count];
_cipherRepository.GetByIdAsync(cipher.Id, userId).Returns(cipher);
events.Add(new EventModel
{
Type = EventType.Cipher_ClientAutofilled,
CipherId = cipher.Id,
Date = DateTime.UtcNow
});
}
var result = await _sut.Post(events);
Assert.IsType<OkResult>(result);
await _eventService.Received(1).LogCipherEventsAsync(
Arg.Is<IEnumerable<Tuple<Cipher, EventType, DateTime?>>>(tuples => tuples.Count() == 50)
);
}
}

View File

@@ -9,14 +9,18 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="NSubstitute" Version="$(NSubstituteVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio"
Version="$(XUnitRunnerVisualStudioVersion)">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
<PackageReference Include="AutoFixture.Xunit2" Version="$(AutoFixtureXUnit2Version)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Events\Events.csproj" />
<ProjectReference Include="..\..\src\Core\Core.csproj" />
<ProjectReference Include="..\Common\Common.csproj" />
</ItemGroup>
</Project>

View File

@@ -1,10 +0,0 @@
namespace Events.Test;
// Delete this file once you have real tests
public class PlaceholderUnitTest
{
[Fact]
public void Test1()
{
}
}

View File

@@ -1,6 +1,7 @@
using System.Reflection;
using AutoFixture;
using AutoFixture.Xunit2;
using Bit.Identity.IdentityServer;
using Duende.IdentityServer.Validation;
namespace Bit.Identity.Test.AutoFixture;
@@ -8,7 +9,8 @@ namespace Bit.Identity.Test.AutoFixture;
internal class ValidatedTokenRequestCustomization : ICustomization
{
public ValidatedTokenRequestCustomization()
{ }
{
}
public void Customize(IFixture fixture)
{
@@ -22,10 +24,45 @@ internal class ValidatedTokenRequestCustomization : ICustomization
public class ValidatedTokenRequestAttribute : CustomizeAttribute
{
public ValidatedTokenRequestAttribute()
{ }
{
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new ValidatedTokenRequestCustomization();
}
}
internal class CustomValidatorRequestContextCustomization : ICustomization
{
public CustomValidatorRequestContextCustomization()
{
}
/// <summary>
/// Specific context members like <see cref="CustomValidatorRequestContext.RememberMeRequested" />,
/// <see cref="CustomValidatorRequestContext.TwoFactorRecoveryRequested"/>, and
/// <see cref="CustomValidatorRequestContext.SsoRequired" /> should initialize false,
/// and are made truthy in context upon evaluation of a request. Do not allow AutoFixture to eagerly make these
/// truthy; that is the responsibility of the <see cref="Bit.Identity.IdentityServer.RequestValidators.BaseRequestValidator{T}" />
/// </summary>
public void Customize(IFixture fixture)
{
fixture.Customize<CustomValidatorRequestContext>(composer => composer
.With(o => o.RememberMeRequested, false)
.With(o => o.TwoFactorRecoveryRequested, false)
.With(o => o.SsoRequired, false));
}
}
public class CustomValidatorRequestContextAttribute : CustomizeAttribute
{
public CustomValidatorRequestContextAttribute()
{
}
public override ICustomization GetCustomization(ParameterInfo parameter)
{
return new CustomValidatorRequestContextCustomization();
}
}

View File

@@ -100,19 +100,30 @@ public class BaseRequestValidatorTests
_userAccountKeysQuery);
}
private void SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(bool recoveryCodeSupportEnabled)
{
_featureService
.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)
.Returns(recoveryCodeSupportEnabled);
}
/* Logic path
* ValidateAsync -> UpdateFailedAuthDetailsAsync -> _mailService.SendFailedLoginAttemptsEmailAsync
* |-> BuildErrorResultAsync -> _eventService.LogUserEventAsync
* (self hosted) |-> _logger.LogWarning()
* |-> SetErrorResult
*/
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ContextNotValid_SelfHosted_ShouldBuildErrorResult_ShouldLogWarning(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_globalSettings.SelfHosted = true;
_sut.isValid = false;
@@ -122,18 +133,23 @@ public class BaseRequestValidatorTests
// Assert
var logs = _logger.Collector.GetSnapshot(true);
Assert.Contains(logs, l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
Assert.Contains(logs,
l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: ");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_DeviceNotValidated_ShouldLogError(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -141,14 +157,15 @@ public class BaseRequestValidatorTests
// 2 -> will result to false with no extra configuration
// 3 -> set two factor to be false
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to fail
requestContext.KnownDevice = false;
tokenRequest.GrantType = "password";
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(false));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
@@ -163,13 +180,17 @@ public class BaseRequestValidatorTests
.LogUserEventAsync(context.CustomValidatorRequestContext.User.Id, EventType.User_FailedLogIn);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_DeviceValidated_ShouldSucceed(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -177,12 +198,13 @@ public class BaseRequestValidatorTests
// 2 -> will result to false with no extra configuration
// 3 -> set two factor to be false
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
.RequiresTwoFactorAsync(Arg.Any<User>(), tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to pass
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// 5 -> not legacy user
_userService.IsLegacyUser(Arg.Any<string>())
@@ -202,13 +224,17 @@ public class BaseRequestValidatorTests
Assert.False(context.GrantResult.IsError);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ValidatedAuthRequest_ConsumedOnSuccess(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -235,7 +261,8 @@ public class BaseRequestValidatorTests
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
// 4 -> set up device validator to pass
_deviceValidator.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
_deviceValidator
.ValidateRequestDeviceAsync(Arg.Any<ValidatedTokenRequest>(), Arg.Any<CustomValidatorRequestContext>())
.Returns(Task.FromResult(true));
// 5 -> not legacy user
@@ -260,13 +287,17 @@ public class BaseRequestValidatorTests
ar.AuthenticationDate.HasValue));
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_ValidatedAuthRequest_NotConsumed_When2faRequired(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
// 1 -> to pass
_sut.isValid = true;
@@ -302,13 +333,17 @@ public class BaseRequestValidatorTests
await _authRequestRepository.DidNotReceive().ReplaceAsync(Arg.Any<AuthRequest>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_TwoFactorTokenInvalid_ShouldSendFailedTwoFactorEmail(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -345,13 +380,17 @@ public class BaseRequestValidatorTests
Arg.Any<string>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_TwoFactorRememberTokenExpired_ShouldNotSendFailedTwoFactorEmail(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
@@ -391,28 +430,34 @@ public class BaseRequestValidatorTests
// Assert
// Verify that the failed 2FA email was NOT sent for remember token expiration
await _mailService.DidNotReceive()
.SendFailedTwoFactorAttemptEmailAsync(Arg.Any<string>(), Arg.Any<TwoFactorProviderType>(), Arg.Any<DateTime>(), Arg.Any<string>());
.SendFailedTwoFactorAttemptEmailAsync(Arg.Any<string>(), Arg.Any<TwoFactorProviderType>(),
Arg.Any<DateTime>(), Arg.Any<string>());
}
// Test grantTypes that require SSO when a user is in an organization that requires it
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_OrgSsoRequiredTrue_ShouldSetSsoResult(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
context.ValidatedTokenRequest.GrantType = grantType;
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
@@ -425,16 +470,21 @@ public class BaseRequestValidatorTests
// Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredTrue_ShouldSetSsoResult(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -449,23 +499,28 @@ public class BaseRequestValidatorTests
// Assert
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
Assert.Equal("SSO authentication is required.", errorResponse.Message);
}
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -500,24 +555,29 @@ public class BaseRequestValidatorTests
// Test grantTypes where SSO would be required but the user is not in an
// organization that requires it
[Theory]
[BitAutoData("password")]
[BitAutoData("webauthn")]
[BitAutoData("refresh_token")]
[BitAutoData("password", true)]
[BitAutoData("password", false)]
[BitAutoData("webauthn", true)]
[BitAutoData("webauthn", false)]
[BitAutoData("refresh_token", true)]
[BitAutoData("refresh_token", false)]
public async Task ValidateAsync_GrantTypes_OrgSsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
context.ValidatedTokenRequest.GrantType = grantType;
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
_twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(false, null)));
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
@@ -540,20 +600,23 @@ public class BaseRequestValidatorTests
await _userRepository.Received(1).ReplaceAsync(Arg.Any<User>());
Assert.False(context.GrantResult.IsError);
}
// Test the grantTypes where SSO is in progress or not relevant
[Theory]
[BitAutoData("authorization_code")]
[BitAutoData("client_credentials")]
[BitAutoData("authorization_code", true)]
[BitAutoData("authorization_code", false)]
[BitAutoData("client_credentials", true)]
[BitAutoData("client_credentials", false)]
public async Task ValidateAsync_GrantTypes_SsoRequiredFalse_ShouldSucceed(
string grantType,
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
_sut.isValid = true;
@@ -577,7 +640,7 @@ public class BaseRequestValidatorTests
// Assert
await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed);
await _eventService.Received(1).LogUserEventAsync(
context.CustomValidatorRequestContext.User.Id, EventType.User_LoggedIn);
await _userRepository.Received(1).ReplaceAsync(Arg.Any<User>());
@@ -588,13 +651,17 @@ public class BaseRequestValidatorTests
/* Logic Path
* ValidateAsync -> UserService.IsLegacyUser -> FailAuthForLegacyUserAsync
*/
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_IsLegacyUser_FailAuthForLegacyUserAsync(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = context.CustomValidatorRequestContext.User;
user.Key = null;
@@ -613,21 +680,27 @@ public class BaseRequestValidatorTests
// Assert
Assert.True(context.GrantResult.IsError);
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
var expectedMessage = "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
var expectedMessage =
"Legacy encryption without a userkey is no longer supported. To recover your account, please contact support";
Assert.Equal(expectedMessage, errorResponse.Message);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_NoMasterPassword_ShouldSetUserDecryptionOptions(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = false,
@@ -663,19 +736,24 @@ public class BaseRequestValidatorTests
}
[Theory]
[BitAutoData(KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(KdfType.Argon2id, 11, 128, 5)]
[BitAutoData(true, KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(false, KdfType.PBKDF2_SHA256, 654_321, null, null)]
[BitAutoData(true, KdfType.Argon2id, 11, 128, 5)]
[BitAutoData(false, KdfType.Argon2id, 11, 128, 5)]
public async Task ValidateAsync_CustomResponse_MasterPassword_ShouldSetUserDecryptionOptions(
bool featureFlagValue,
KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = true,
@@ -728,13 +806,17 @@ public class BaseRequestValidatorTests
Assert.Equal("test@example.com", userDecryptionOptions.MasterPasswordUnlock.Salt);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_ShouldIncludeAccountKeys(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var mockAccountKeys = new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
@@ -747,11 +829,7 @@ public class BaseRequestValidatorTests
"test-wrapped-signing-key",
"test-verifying-key"
),
SecurityStateData = new SecurityStateData
{
SecurityState = "test-security-state",
SecurityVersion = 2
}
SecurityStateData = new SecurityStateData { SecurityState = "test-security-state", SecurityVersion = 2 }
};
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(mockAccountKeys);
@@ -759,7 +837,8 @@ public class BaseRequestValidatorTests
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions
{
HasMasterPassword = true,
@@ -808,13 +887,18 @@ public class BaseRequestValidatorTests
Assert.Equal("test-security-state", accountKeysResponse.SecurityState.SecurityState);
Assert.Equal(2, accountKeysResponse.SecurityState.SecurityVersion);
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_AccountKeysQuery_SkippedWhenPrivateKeyIsNull(
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
requestContext.User.PrivateKey = null;
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -833,13 +917,18 @@ public class BaseRequestValidatorTests
// Verify that the account keys query wasn't called.
await _userAccountKeysQuery.Received(0).Run(Arg.Any<User>());
}
[Theory, BitAutoData]
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_CustomResponse_AccountKeysQuery_CalledWithCorrectUser(
bool featureFlagValue,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue);
var expectedUser = requestContext.User;
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
@@ -853,7 +942,8 @@ public class BaseRequestValidatorTests
_userDecryptionOptionsBuilder.ForUser(Arg.Any<User>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithDevice(Arg.Any<Device>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithSso(Arg.Any<SsoConfig>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>()).Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.WithWebAuthnLoginCredential(Arg.Any<WebAuthnCredential>())
.Returns(_userDecryptionOptionsBuilder);
_userDecryptionOptionsBuilder.BuildAsync().Returns(Task.FromResult(new UserDecryptionOptions()));
var context = CreateContext(tokenRequest, requestContext, grantResult);
@@ -874,6 +964,285 @@ public class BaseRequestValidatorTests
await _userAccountKeysQuery.Received(1).Run(Arg.Is<User>(u => u.Id == expectedUser.Id));
}
/// <summary>
/// Tests the core PM-21153 feature: SSO-required users can use recovery codes to disable 2FA,
/// but must then authenticate via SSO with a descriptive message about the recovery.
/// This test validates:
/// 1. Validation order is changed (2FA before SSO) when recovery code is provided
/// 2. Recovery code successfully validates and sets TwoFactorRecoveryRequested flag
/// 3. SSO validation then fails with recovery-specific message
/// 4. User is NOT logged in (must authenticate via IdP)
/// </summary>
[Theory]
[BitAutoData(true)] // Feature flag ON - new behavior
[BitAutoData(false)] // Feature flag OFF - should fail at SSO before 2FA recovery
public async Task ValidateAsync_RecoveryCodeForSsoRequiredUser_BlocksWithDescriptiveMessage(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// Reset state that AutoFixture may have populated
requestContext.TwoFactorRecoveryRequested = false;
requestContext.RememberMeRequested = false;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is required (this user is in an org that requires SSO)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE (this triggers the special validation order)
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-12345";
// 5. Recovery code is valid (UserService.RecoverTwoFactorAsync will be called internally)
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-12345")
.Returns(Task.FromResult(true));
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
if (featureFlagEnabled)
{
// NEW BEHAVIOR: Recovery succeeds, then SSO blocks with descriptive message
Assert.Equal(
"Two-factor recovery has been performed. SSO authentication is required.",
errorResponse.Message);
// Verify recovery was marked
Assert.True(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested flag should be set");
}
else
{
// LEGACY BEHAVIOR: SSO blocks BEFORE recovery can happen
Assert.Equal(
"SSO authentication is required.",
errorResponse.Message);
// Recovery never happened because SSO checked first
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false (SSO blocked first)");
}
// In both cases: User is NOT logged in
await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn);
}
/// <summary>
/// Tests that validation order changes when a recovery code is PROVIDED (even if invalid).
/// This ensures the RecoveryCodeRequestForSsoRequiredUserScenario() logic is based on
/// request structure, not validation outcome. An SSO-required user who provides an
/// INVALID recovery code should:
/// 1. Have 2FA validated BEFORE SSO (new order)
/// 2. Get a 2FA error (invalid token)
/// 3. NOT get the recovery-specific SSO message (because recovery didn't complete)
/// 4. NOT be logged in
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_InvalidRecoveryCodeForSsoRequiredUser_FailsAt2FA(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is required
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(true));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE (triggers validation order change)
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "INVALID-recovery-code";
// 5. Recovery code is INVALID
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "INVALID-recovery-code")
.Returns(Task.FromResult(false));
// 6. Setup for failed 2FA email (if feature flag enabled)
_featureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail).Returns(true);
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code");
var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"];
if (featureFlagEnabled)
{
// NEW BEHAVIOR: 2FA is checked first (due to recovery code request), fails with 2FA error
Assert.Equal(
"Two-step token is invalid. Try again.",
errorResponse.Message);
// Recovery was attempted but failed - flag should NOT be set
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false (recovery failed)");
// Verify failed 2FA email was sent
await _mailService.Received(1).SendFailedTwoFactorAttemptEmailAsync(
user.Email,
TwoFactorProviderType.RecoveryCode,
Arg.Any<DateTime>(),
Arg.Any<string>());
// Verify failed login event was logged
await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_FailedLogIn2fa);
}
else
{
// LEGACY BEHAVIOR: SSO is checked first, blocks before 2FA
Assert.Equal(
"SSO authentication is required.",
errorResponse.Message);
// 2FA validation never happened
await _mailService.DidNotReceive().SendFailedTwoFactorAttemptEmailAsync(
Arg.Any<string>(),
Arg.Any<TwoFactorProviderType>(),
Arg.Any<DateTime>(),
Arg.Any<string>());
}
// In both cases: User is NOT logged in
await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn);
// Verify user failed login count was updated (in new behavior path)
if (featureFlagEnabled)
{
await _userRepository.Received(1).ReplaceAsync(Arg.Is<User>(u =>
u.Id == user.Id && u.FailedLoginCount > 0));
}
}
/// <summary>
/// Tests that non-SSO users can successfully use recovery codes to disable 2FA and log in.
/// This validates:
/// 1. Validation order changes to 2FA-first when recovery code is provided
/// 2. Recovery code validates successfully
/// 3. SSO check passes (user not in SSO-required org)
/// 4. User successfully logs in
/// 5. TwoFactorRecoveryRequested flag is set (for logging/audit purposes)
/// This is the "happy path" for recovery code usage.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public async Task ValidateAsync_RecoveryCodeForNonSsoUser_SuccessfulLogin(
bool featureFlagEnabled,
[AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest,
[AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext,
GrantValidationResult grantResult)
{
// Arrange
SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled);
var context = CreateContext(tokenRequest, requestContext, grantResult);
var user = requestContext.User;
// 1. Master password is valid
_sut.isValid = true;
// 2. SSO is NOT required (this is a regular user, not in SSO org)
_policyService.AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed)
.Returns(Task.FromResult(false));
// 3. 2FA is required
_twoFactorAuthenticationValidator
.RequiresTwoFactorAsync(user, tokenRequest)
.Returns(Task.FromResult(new Tuple<bool, Organization>(true, null)));
// 4. Provide a RECOVERY CODE
tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString();
tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code-67890";
// 5. Recovery code is valid
_twoFactorAuthenticationValidator
.VerifyTwoFactorAsync(user, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code-67890")
.Returns(Task.FromResult(true));
// 6. Device validation passes
_deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext)
.Returns(Task.FromResult(true));
// 7. User is not legacy
_userService.IsLegacyUser(Arg.Any<string>())
.Returns(false);
// 8. Setup user account keys for successful login response
_userAccountKeysQuery.Run(Arg.Any<User>()).Returns(new UserAccountKeysData
{
PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData(
"test-private-key",
"test-public-key"
)
});
// Act
await _sut.ValidateAsync(context);
// Assert
Assert.False(context.GrantResult.IsError, "Authentication should succeed for non-SSO user with valid recovery code");
// Verify user successfully logged in
await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_LoggedIn);
// Verify failed login count was reset (successful login)
await _userRepository.Received(1).ReplaceAsync(Arg.Is<User>(u =>
u.Id == user.Id && u.FailedLoginCount == 0));
if (featureFlagEnabled)
{
// NEW BEHAVIOR: Recovery flag should be set for audit purposes
Assert.True(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested flag should be set for audit/logging");
}
else
{
// LEGACY BEHAVIOR: Recovery flag doesn't exist, but login still succeeds
// (SSO check happens before 2FA in legacy, but user is not SSO-required so both pass)
Assert.False(requestContext.TwoFactorRecoveryRequested,
"TwoFactorRecoveryRequested should be false in legacy mode");
}
}
private BaseRequestValidationContextFake CreateContext(
ValidatedTokenRequest tokenRequest,
CustomValidatorRequestContext requestContext,

View File

@@ -1,6 +1,7 @@
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Dirt.Entities;
using Bit.Core.Dirt.Reports.Models.Data;
using Bit.Core.Dirt.Repositories;
using Bit.Core.Repositories;
using Bit.Core.Test.AutoFixture.Attributes;
@@ -49,6 +50,75 @@ public class OrganizationReportRepositoryTests
Assert.True(records.Count == 4);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task CreateAsync_ShouldPersistAllMetricProperties_WhenSet(
List<EntityFramework.Dirt.Repositories.OrganizationReportRepository> suts,
List<EfRepo.OrganizationRepository> efOrganizationRepos,
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange - Create a report with explicit metric values
var fixture = new Fixture();
var organization = fixture.Create<Organization>();
var report = fixture.Build<OrganizationReport>()
.With(x => x.ApplicationCount, 10)
.With(x => x.ApplicationAtRiskCount, 3)
.With(x => x.CriticalApplicationCount, 5)
.With(x => x.CriticalApplicationAtRiskCount, 2)
.With(x => x.MemberCount, 25)
.With(x => x.MemberAtRiskCount, 7)
.With(x => x.CriticalMemberCount, 12)
.With(x => x.CriticalMemberAtRiskCount, 4)
.With(x => x.PasswordCount, 100)
.With(x => x.PasswordAtRiskCount, 15)
.With(x => x.CriticalPasswordCount, 50)
.With(x => x.CriticalPasswordAtRiskCount, 8)
.Create();
var retrievedReports = new List<OrganizationReport>();
// Act & Assert - Test EF repositories
foreach (var sut in suts)
{
var i = suts.IndexOf(sut);
var efOrganization = await efOrganizationRepos[i].CreateAsync(organization);
sut.ClearChangeTracking();
report.OrganizationId = efOrganization.Id;
var createdReport = await sut.CreateAsync(report);
sut.ClearChangeTracking();
var savedReport = await sut.GetByIdAsync(createdReport.Id);
retrievedReports.Add(savedReport);
}
// Act & Assert - Test SQL repository
var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization);
report.OrganizationId = sqlOrganization.Id;
var sqlCreatedReport = await sqlOrganizationReportRepo.CreateAsync(report);
var savedSqlReport = await sqlOrganizationReportRepo.GetByIdAsync(sqlCreatedReport.Id);
retrievedReports.Add(savedSqlReport);
// Assert - Verify all metric properties are persisted correctly across all repositories
Assert.True(retrievedReports.Count == 4);
foreach (var retrievedReport in retrievedReports)
{
Assert.NotNull(retrievedReport);
Assert.Equal(10, retrievedReport.ApplicationCount);
Assert.Equal(3, retrievedReport.ApplicationAtRiskCount);
Assert.Equal(5, retrievedReport.CriticalApplicationCount);
Assert.Equal(2, retrievedReport.CriticalApplicationAtRiskCount);
Assert.Equal(25, retrievedReport.MemberCount);
Assert.Equal(7, retrievedReport.MemberAtRiskCount);
Assert.Equal(12, retrievedReport.CriticalMemberCount);
Assert.Equal(4, retrievedReport.CriticalMemberAtRiskCount);
Assert.Equal(100, retrievedReport.PasswordCount);
Assert.Equal(15, retrievedReport.PasswordAtRiskCount);
Assert.Equal(50, retrievedReport.CriticalPasswordCount);
Assert.Equal(8, retrievedReport.CriticalPasswordAtRiskCount);
}
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task RetrieveByOrganisation_Works(
OrganizationReportRepository sqlOrganizationReportRepo,
@@ -66,6 +136,67 @@ public class OrganizationReportRepositoryTests
Assert.Equal(secondOrg.Id, secondRetrievedReport.OrganizationId);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task UpdateAsync_ShouldUpdateAllMetricProperties_WhenChanged(
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange - Create initial report with specific metric values
var fixture = new Fixture();
var organization = fixture.Create<Organization>();
var org = await sqlOrganizationRepo.CreateAsync(organization);
var report = fixture.Build<OrganizationReport>()
.With(x => x.OrganizationId, org.Id)
.With(x => x.ApplicationCount, 10)
.With(x => x.ApplicationAtRiskCount, 3)
.With(x => x.CriticalApplicationCount, 5)
.With(x => x.CriticalApplicationAtRiskCount, 2)
.With(x => x.MemberCount, 25)
.With(x => x.MemberAtRiskCount, 7)
.With(x => x.CriticalMemberCount, 12)
.With(x => x.CriticalMemberAtRiskCount, 4)
.With(x => x.PasswordCount, 100)
.With(x => x.PasswordAtRiskCount, 15)
.With(x => x.CriticalPasswordCount, 50)
.With(x => x.CriticalPasswordAtRiskCount, 8)
.Create();
var createdReport = await sqlOrganizationReportRepo.CreateAsync(report);
// Act - Update all metric properties with new values
createdReport.ApplicationCount = 20;
createdReport.ApplicationAtRiskCount = 6;
createdReport.CriticalApplicationCount = 10;
createdReport.CriticalApplicationAtRiskCount = 4;
createdReport.MemberCount = 50;
createdReport.MemberAtRiskCount = 14;
createdReport.CriticalMemberCount = 24;
createdReport.CriticalMemberAtRiskCount = 8;
createdReport.PasswordCount = 200;
createdReport.PasswordAtRiskCount = 30;
createdReport.CriticalPasswordCount = 100;
createdReport.CriticalPasswordAtRiskCount = 16;
await sqlOrganizationReportRepo.UpsertAsync(createdReport);
// Assert - Verify all metric properties were updated correctly
var updatedReport = await sqlOrganizationReportRepo.GetByIdAsync(createdReport.Id);
Assert.NotNull(updatedReport);
Assert.Equal(20, updatedReport.ApplicationCount);
Assert.Equal(6, updatedReport.ApplicationAtRiskCount);
Assert.Equal(10, updatedReport.CriticalApplicationCount);
Assert.Equal(4, updatedReport.CriticalApplicationAtRiskCount);
Assert.Equal(50, updatedReport.MemberCount);
Assert.Equal(14, updatedReport.MemberAtRiskCount);
Assert.Equal(24, updatedReport.CriticalMemberCount);
Assert.Equal(8, updatedReport.CriticalMemberAtRiskCount);
Assert.Equal(200, updatedReport.PasswordCount);
Assert.Equal(30, updatedReport.PasswordAtRiskCount);
Assert.Equal(100, updatedReport.CriticalPasswordCount);
Assert.Equal(16, updatedReport.CriticalPasswordAtRiskCount);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task Delete_Works(
List<EntityFramework.Dirt.Repositories.OrganizationReportRepository> suts,
@@ -359,6 +490,49 @@ public class OrganizationReportRepositoryTests
Assert.Null(result);
}
[CiSkippedTheory, EfOrganizationReportAutoData]
public async Task UpdateMetricsAsync_ShouldUpdateMetricsCorrectly(
OrganizationReportRepository sqlOrganizationReportRepo,
SqlRepo.OrganizationRepository sqlOrganizationRepo)
{
// Arrange
var (org, report) = await CreateOrganizationAndReportAsync(sqlOrganizationRepo, sqlOrganizationReportRepo);
var metrics = new OrganizationReportMetricsData
{
ApplicationCount = 10,
ApplicationAtRiskCount = 2,
CriticalApplicationCount = 5,
CriticalApplicationAtRiskCount = 1,
MemberCount = 20,
MemberAtRiskCount = 4,
CriticalMemberCount = 10,
CriticalMemberAtRiskCount = 2,
PasswordCount = 100,
PasswordAtRiskCount = 15,
CriticalPasswordCount = 50,
CriticalPasswordAtRiskCount = 5
};
// Act
await sqlOrganizationReportRepo.UpdateMetricsAsync(report.Id, metrics);
var updatedReport = await sqlOrganizationReportRepo.GetByIdAsync(report.Id);
// Assert
Assert.Equal(metrics.ApplicationCount, updatedReport.ApplicationCount);
Assert.Equal(metrics.ApplicationAtRiskCount, updatedReport.ApplicationAtRiskCount);
Assert.Equal(metrics.CriticalApplicationCount, updatedReport.CriticalApplicationCount);
Assert.Equal(metrics.CriticalApplicationAtRiskCount, updatedReport.CriticalApplicationAtRiskCount);
Assert.Equal(metrics.MemberCount, updatedReport.MemberCount);
Assert.Equal(metrics.MemberAtRiskCount, updatedReport.MemberAtRiskCount);
Assert.Equal(metrics.CriticalMemberCount, updatedReport.CriticalMemberCount);
Assert.Equal(metrics.CriticalMemberAtRiskCount, updatedReport.CriticalMemberAtRiskCount);
Assert.Equal(metrics.PasswordCount, updatedReport.PasswordCount);
Assert.Equal(metrics.PasswordAtRiskCount, updatedReport.PasswordAtRiskCount);
Assert.Equal(metrics.CriticalPasswordCount, updatedReport.CriticalPasswordCount);
Assert.Equal(metrics.CriticalPasswordAtRiskCount, updatedReport.CriticalPasswordAtRiskCount);
}
private async Task<(Organization, OrganizationReport)> CreateOrganizationAndReportAsync(
IOrganizationRepository orgRepo,
IOrganizationReportRepository orgReportRepo)

View File

@@ -33,14 +33,69 @@ public static class OrganizationTestHelpers
public static Task<Organization> CreateTestOrganizationAsync(this IOrganizationRepository organizationRepository,
int? seatCount = null,
string identifier = "test")
=> organizationRepository.CreateAsync(new Organization
{
var id = Guid.NewGuid();
return organizationRepository.CreateAsync(new Organization
{
Name = $"{identifier}-{Guid.NewGuid()}",
BillingEmail = "billing@example.com", // TODO: EF does not enforce this being NOT NULL
Plan = "Enterprise (Annually)", // TODO: EF does not enforce this being NOT NULl
Name = $"{identifier}-{id}",
BillingEmail = $"billing-{id}@example.com",
Plan = "Enterprise (Annually)",
PlanType = PlanType.EnterpriseAnnually,
Seats = seatCount
Identifier = $"{identifier}-{id}",
BusinessName = $"Test Business {id}",
BusinessAddress1 = "123 Test Street",
BusinessAddress2 = "Suite 100",
BusinessAddress3 = "Building A",
BusinessCountry = "US",
BusinessTaxNumber = "123456789",
Seats = seatCount,
MaxCollections = 50,
UsePolicies = true,
UseSso = true,
UseKeyConnector = true,
UseScim = true,
UseGroups = true,
UseDirectory = true,
UseEvents = true,
UseTotp = true,
Use2fa = true,
UseApi = true,
UseResetPassword = true,
UseSecretsManager = true,
UsePasswordManager = true,
SelfHost = false,
UsersGetPremium = true,
UseCustomPermissions = true,
Storage = 1073741824, // 1 GB in bytes
MaxStorageGb = 10,
Gateway = GatewayType.Stripe,
GatewayCustomerId = $"cus_{id}",
GatewaySubscriptionId = $"sub_{id}",
ReferenceData = "{\"test\":\"data\"}",
Enabled = true,
LicenseKey = $"license-{id}",
PublicKey = "test-public-key",
PrivateKey = "test-private-key",
TwoFactorProviders = null,
ExpirationDate = DateTime.UtcNow.AddYears(1),
MaxAutoscaleSeats = 200,
OwnersNotifiedOfAutoscaling = null,
Status = OrganizationStatusType.Managed,
SmSeats = 50,
SmServiceAccounts = 25,
MaxAutoscaleSmSeats = 100,
MaxAutoscaleSmServiceAccounts = 50,
LimitCollectionCreation = true,
LimitCollectionDeletion = true,
LimitItemDeletion = true,
AllowAdminAccessToAllCollectionItems = true,
UseRiskInsights = true,
UseOrganizationDomains = true,
UseAdminSponsoredFamilies = true,
SyncSeats = false,
UseAutomaticUserConfirmation = true
});
}
/// <summary>
/// Creates a confirmed Owner for the specified organization and user.

View File

@@ -461,13 +461,7 @@ public class OrganizationUserRepositoryTests
KdfParallelism = 3
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL
Plan = "Test", // TODO: EF does not enforce this being NOT NULL
PrivateKey = "privatekey",
});
var organization = await organizationRepository.CreateTestOrganizationAsync();
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
@@ -536,9 +530,72 @@ public class OrganizationUserRepositoryTests
Assert.Equal(organization.SmServiceAccounts, result.SmServiceAccounts);
Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation);
Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion);
Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion);
Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems);
Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights);
Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains);
Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies);
Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation);
}
[Theory, DatabaseData]
public async Task GetManyDetailsByUserAsync_ShouldPopulateSsoPropertiesCorrectly(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ISsoConfigRepository ssoConfigRepository)
{
var user = await userRepository.CreateTestUserAsync();
var organizationWithSso = await organizationRepository.CreateTestOrganizationAsync();
var organizationWithoutSso = await organizationRepository.CreateTestOrganizationAsync();
var orgUserWithSso = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organizationWithSso.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
Email = user.Email
});
var orgUserWithoutSso = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organizationWithoutSso.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
Email = user.Email
});
// Create SSO configuration for first organization only
var serializedSsoConfigData = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize();
var ssoConfig = await ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organizationWithSso.Id,
Enabled = true,
Data = serializedSsoConfigData
});
var results = (await organizationUserRepository.GetManyDetailsByUserAsync(user.Id)).ToList();
Assert.Equal(2, results.Count);
var orgWithSsoDetails = results.Single(r => r.OrganizationId == organizationWithSso.Id);
var orgWithoutSsoDetails = results.Single(r => r.OrganizationId == organizationWithoutSso.Id);
// Organization with SSO should have SSO properties populated
Assert.True(orgWithSsoDetails.SsoEnabled);
Assert.NotNull(orgWithSsoDetails.SsoConfig);
Assert.Equal(serializedSsoConfigData, orgWithSsoDetails.SsoConfig);
// Organization without SSO should have null SSO properties
Assert.Null(orgWithoutSsoDetails.SsoEnabled);
Assert.Null(orgWithoutSsoDetails.SsoConfig);
}
[DatabaseTheory, DatabaseData]
@@ -1427,6 +1484,8 @@ public class OrganizationUserRepositoryTests
var organization = await organizationRepository.CreateTestOrganizationAsync();
var user = await userRepository.CreateTestUserAsync();
var orgUser = await organizationUserRepository.CreateAcceptedTestOrganizationUserAsync(organization, user);
const string key = "test-key";
orgUser.Key = key;
// Act
var result = await organizationUserRepository.ConfirmOrganizationUserAsync(orgUser);
@@ -1436,6 +1495,7 @@ public class OrganizationUserRepositoryTests
var updatedUser = await organizationUserRepository.GetByIdAsync(orgUser.Id);
Assert.NotNull(updatedUser);
Assert.Equal(OrganizationUserStatusType.Confirmed, updatedUser.Status);
Assert.Equal(key, updatedUser.Key);
// Annul
await organizationRepository.DeleteAsync(organization);

View File

@@ -0,0 +1,142 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Repositories;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories;
public class ProviderUserRepositoryTests
{
[Theory, DatabaseData]
public async Task GetManyOrganizationDetailsByUserAsync_ShouldPopulatePropertiesCorrectly(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository,
ISsoConfigRepository ssoConfigRepository)
{
var user = await userRepository.CreateTestUserAsync();
var organizationWithSso = await organizationRepository.CreateTestOrganizationAsync();
var organizationWithoutSso = await organizationRepository.CreateTestOrganizationAsync();
var provider = await providerRepository.CreateAsync(new Provider
{
Name = "Test Provider",
Enabled = true,
Type = ProviderType.Msp
});
var providerUser = await providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user.Id,
Status = ProviderUserStatusType.Confirmed,
Type = ProviderUserType.ProviderAdmin
});
var providerOrganizationWithSso = await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organizationWithSso.Id
});
var providerOrganizationWithoutSso = await providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organizationWithoutSso.Id
});
// Create SSO configuration for first organization only
var serializedSsoConfigData = new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.KeyConnector,
KeyConnectorUrl = "https://keyconnector.example.com"
}.Serialize();
var ssoConfig = await ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organizationWithSso.Id,
Enabled = true,
Data = serializedSsoConfigData
});
var results = (await providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed)).ToList();
Assert.Equal(2, results.Count);
var orgWithSsoDetails = results.Single(r => r.OrganizationId == organizationWithSso.Id);
var orgWithoutSsoDetails = results.Single(r => r.OrganizationId == organizationWithoutSso.Id);
// Verify all properties for both organizations
AssertProviderOrganizationDetails(orgWithSsoDetails, organizationWithSso, user, provider, providerUser);
AssertProviderOrganizationDetails(orgWithoutSsoDetails, organizationWithoutSso, user, provider, providerUser);
// Organization without SSO should have null SSO properties
Assert.Null(orgWithoutSsoDetails.SsoEnabled);
Assert.Null(orgWithoutSsoDetails.SsoConfig);
// Organization with SSO should have SSO properties populated
Assert.True(orgWithSsoDetails.SsoEnabled);
Assert.NotNull(orgWithSsoDetails.SsoConfig);
Assert.Equal(serializedSsoConfigData, orgWithSsoDetails.SsoConfig);
}
private static void AssertProviderOrganizationDetails(
ProviderUserOrganizationDetails actual,
Organization expectedOrganization,
User expectedUser,
Provider expectedProvider,
ProviderUser expectedProviderUser)
{
// Organization properties
Assert.Equal(expectedOrganization.Id, actual.OrganizationId);
Assert.Equal(expectedUser.Id, actual.UserId);
Assert.Equal(expectedOrganization.Name, actual.Name);
Assert.Equal(expectedOrganization.UsePolicies, actual.UsePolicies);
Assert.Equal(expectedOrganization.UseSso, actual.UseSso);
Assert.Equal(expectedOrganization.UseKeyConnector, actual.UseKeyConnector);
Assert.Equal(expectedOrganization.UseScim, actual.UseScim);
Assert.Equal(expectedOrganization.UseGroups, actual.UseGroups);
Assert.Equal(expectedOrganization.UseDirectory, actual.UseDirectory);
Assert.Equal(expectedOrganization.UseEvents, actual.UseEvents);
Assert.Equal(expectedOrganization.UseTotp, actual.UseTotp);
Assert.Equal(expectedOrganization.Use2fa, actual.Use2fa);
Assert.Equal(expectedOrganization.UseApi, actual.UseApi);
Assert.Equal(expectedOrganization.UseResetPassword, actual.UseResetPassword);
Assert.Equal(expectedOrganization.UsersGetPremium, actual.UsersGetPremium);
Assert.Equal(expectedOrganization.UseCustomPermissions, actual.UseCustomPermissions);
Assert.Equal(expectedOrganization.SelfHost, actual.SelfHost);
Assert.Equal(expectedOrganization.Seats, actual.Seats);
Assert.Equal(expectedOrganization.MaxCollections, actual.MaxCollections);
Assert.Equal(expectedOrganization.MaxStorageGb, actual.MaxStorageGb);
Assert.Equal(expectedOrganization.Identifier, actual.Identifier);
Assert.Equal(expectedOrganization.PublicKey, actual.PublicKey);
Assert.Equal(expectedOrganization.PrivateKey, actual.PrivateKey);
Assert.Equal(expectedOrganization.Enabled, actual.Enabled);
Assert.Equal(expectedOrganization.PlanType, actual.PlanType);
Assert.Equal(expectedOrganization.LimitCollectionCreation, actual.LimitCollectionCreation);
Assert.Equal(expectedOrganization.LimitCollectionDeletion, actual.LimitCollectionDeletion);
Assert.Equal(expectedOrganization.LimitItemDeletion, actual.LimitItemDeletion);
Assert.Equal(expectedOrganization.AllowAdminAccessToAllCollectionItems, actual.AllowAdminAccessToAllCollectionItems);
Assert.Equal(expectedOrganization.UseRiskInsights, actual.UseRiskInsights);
Assert.Equal(expectedOrganization.UseOrganizationDomains, actual.UseOrganizationDomains);
Assert.Equal(expectedOrganization.UseAdminSponsoredFamilies, actual.UseAdminSponsoredFamilies);
Assert.Equal(expectedOrganization.UseAutomaticUserConfirmation, actual.UseAutomaticUserConfirmation);
// Provider-specific properties
Assert.Equal(expectedProvider.Id, actual.ProviderId);
Assert.Equal(expectedProvider.Name, actual.ProviderName);
Assert.Equal(expectedProvider.Type, actual.ProviderType);
Assert.Equal(expectedProviderUser.Id, actual.ProviderUserId);
Assert.Equal(expectedProviderUser.Status, actual.Status);
Assert.Equal(expectedProviderUser.Type, actual.Type);
}
}

View File

@@ -345,4 +345,110 @@ public class SecurityTaskRepositoryTests
Assert.Equal(0, metrics.CompletedTasks);
Assert.Equal(0, metrics.TotalTasks);
}
[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_MarksPendingTasksAsCompleted(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "billing@email.com"
});
var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var task1 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
var task2 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id, cipher2.Id]);
var updatedTask1 = await securityTaskRepository.GetByIdAsync(task1.Id);
var updatedTask2 = await securityTaskRepository.GetByIdAsync(task2.Id);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask1.Status);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask2.Status);
}
[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_OnlyUpdatesSpecifiedCiphers(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "billing@email.com"
});
var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});
var taskToUpdate = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
var taskToKeep = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id]);
var updatedTask = await securityTaskRepository.GetByIdAsync(taskToUpdate.Id);
var unchangedTask = await securityTaskRepository.GetByIdAsync(taskToKeep.Id);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask.Status);
Assert.Equal(SecurityTaskStatus.Pending, unchangedTask.Status);
}
}

View File

@@ -1,6 +1,7 @@
using AspNetCoreRateLimit;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Platform.Mail.Delivery;
using Bit.Core.Platform.Push;
using Bit.Core.Platform.PushRegistration.Internal;
using Bit.Core.Repositories;