1
0
mirror of https://github.com/bitwarden/server synced 2026-02-01 01:03:25 +00:00

Merge branch 'main' into ac/pm-28842/cap-password-minimum-length

This commit is contained in:
Rui Tomé
2026-01-15 14:13:57 +00:00
committed by GitHub
318 changed files with 48786 additions and 2929 deletions

View File

@@ -4,7 +4,6 @@ using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.Models.Response;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
@@ -14,8 +13,6 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Xunit;
namespace Bit.Api.IntegrationTest.AdminConsole.Controllers;
@@ -32,12 +29,6 @@ public class OrganizationUserControllerBulkRevokeTests : IClassFixture<ApiApplic
public OrganizationUserControllerBulkRevokeTests(ApiApplicationFactory apiFactory)
{
_factory = apiFactory;
_factory.SubstituteService<IFeatureService>(featureService =>
{
featureService
.IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2)
.Returns(true);
});
_client = _factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}

View File

@@ -1,6 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,19 +1,28 @@
using System.Net;
using System.Text.Json;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.KeyManagement.Models.Requests;
using Bit.Api.Models.Response;
using Bit.Core;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Repositories;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using Xunit;
using static Bit.Core.KeyManagement.Enums.SignatureAlgorithm;
namespace Bit.Api.IntegrationTest.Controllers;
@@ -21,6 +30,8 @@ public class AccountsControllerTest : IClassFixture<ApiApplicationFactory>, IAsy
{
private static readonly string _masterKeyWrappedUserKey =
"2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=";
private static readonly string _mockEncryptedType7String = "7.AOs41Hd8OQiCPXjyJKCiDA==";
private static readonly string _mockEncryptedType7WrappedSigningKey = "7.DRv74Kg1RSlFSam1MNFlGD==";
private static readonly string _masterPasswordHash = "master_password_hash";
private static readonly string _newMasterPasswordHash = "new_master_password_hash";
@@ -35,6 +46,11 @@ public class AccountsControllerTest : IClassFixture<ApiApplicationFactory>, IAsy
private readonly IPushNotificationService _pushNotificationService;
private readonly IFeatureService _featureService;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository;
private readonly IEventRepository _eventRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private string _ownerEmail = null!;
@@ -49,6 +65,11 @@ public class AccountsControllerTest : IClassFixture<ApiApplicationFactory>, IAsy
_pushNotificationService = _factory.GetService<IPushNotificationService>();
_featureService = _factory.GetService<IFeatureService>();
_passwordHasher = _factory.GetService<IPasswordHasher<User>>();
_organizationRepository = _factory.GetService<IOrganizationRepository>();
_ssoConfigRepository = _factory.GetService<ISsoConfigRepository>();
_userSignatureKeyPairRepository = _factory.GetService<IUserSignatureKeyPairRepository>();
_eventRepository = _factory.GetService<IEventRepository>();
_organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
}
public async Task InitializeAsync()
@@ -435,4 +456,531 @@ public class AccountsControllerTest : IClassFixture<ApiApplicationFactory>, IAsy
message.Content = JsonContent.Create(requestModel);
return await _client.SendAsync(message);
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V1_MasterPasswordDecryption_Success(string organizationSsoIdentifier)
{
// Arrange - Create organization and user
var ownerEmail = $"owner-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(ownerEmail);
var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory,
ownerEmail: ownerEmail,
name: "Test Org V1");
organization.UseSso = true;
organization.Identifier = organizationSsoIdentifier;
await _organizationRepository.ReplaceAsync(organization);
await _ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organization.Id,
Enabled = true,
Data = JsonSerializer.Serialize(new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.MasterPassword,
}, JsonHelpers.CamelCase),
});
// Create user with password initially, so we can login
var userEmail = $"user-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(userEmail);
// Add user to organization
var user = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(user);
await OrganizationTestHelpers.CreateUserAsync(_factory, organization.Id, userEmail,
OrganizationUserType.User, userStatusType: OrganizationUserStatusType.Invited);
// Login as the user
await _loginHelper.LoginAsync(userEmail);
// Remove the master password and keys to simulate newly registered SSO user
user.MasterPassword = null;
user.Key = null;
user.PrivateKey = null;
user.PublicKey = null;
await _userRepository.ReplaceAsync(user);
// V1 (Obsolete) request format - to be removed with PM-27327
var request = new
{
masterPasswordHash = _newMasterPasswordHash,
key = _masterKeyWrappedUserKey,
keys = new
{
publicKey = "v1-publicKey",
encryptedPrivateKey = "v1-encryptedPrivateKey"
},
kdf = 0, // PBKDF2_SHA256
kdfIterations = 600000,
kdfMemory = (int?)null,
kdfParallelism = (int?)null,
masterPasswordHint = "v1-integration-test-hint",
orgIdentifier = organization.Identifier
};
var jsonRequest = JsonSerializer.Serialize(request, JsonHelpers.CamelCase);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
Assert.Fail($"Expected success but got {response.StatusCode}. Error: {errorContent}");
}
// Verify user in database
var updatedUser = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(updatedUser);
Assert.Equal("v1-integration-test-hint", updatedUser.MasterPasswordHint);
// Verify the master password is hashed and stored
Assert.NotNull(updatedUser.MasterPassword);
var verificationResult = _passwordHasher.VerifyHashedPassword(updatedUser, updatedUser.MasterPassword, _newMasterPasswordHash);
Assert.Equal(PasswordVerificationResult.Success, verificationResult);
// Verify KDF settings
Assert.Equal(KdfType.PBKDF2_SHA256, updatedUser.Kdf);
Assert.Equal(600_000, updatedUser.KdfIterations);
Assert.Null(updatedUser.KdfMemory);
Assert.Null(updatedUser.KdfParallelism);
// Verify timestamps are updated
Assert.Equal(DateTime.UtcNow, updatedUser.RevisionDate, TimeSpan.FromMinutes(1));
Assert.Equal(DateTime.UtcNow, updatedUser.AccountRevisionDate, TimeSpan.FromMinutes(1));
// Verify keys are set (V1 uses Keys property)
Assert.Equal(_masterKeyWrappedUserKey, updatedUser.Key);
Assert.Equal("v1-publicKey", updatedUser.PublicKey);
Assert.Equal("v1-encryptedPrivateKey", updatedUser.PrivateKey);
// Verify User_ChangedPassword event was logged
var events = await _eventRepository.GetManyByUserAsync(updatedUser.Id, DateTime.UtcNow.AddMinutes(-5), DateTime.UtcNow.AddMinutes(1), new PageOptions { PageSize = 100 });
Assert.NotNull(events);
Assert.Contains(events.Data, e => e.Type == EventType.User_ChangedPassword && e.UserId == updatedUser.Id);
// Verify user was accepted into the organization
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(updatedUser.Id);
var orgUser = orgUsers.FirstOrDefault(ou => ou.OrganizationId == organization.Id);
Assert.NotNull(orgUser);
Assert.Equal(OrganizationUserStatusType.Accepted, orgUser.Status);
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_MasterPasswordDecryption_Success(string organizationSsoIdentifier)
{
// Arrange - Create organization and user
var ownerEmail = $"owner-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(ownerEmail);
var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory,
ownerEmail: ownerEmail,
name: "Test Org");
organization.UseSso = true;
organization.Identifier = organizationSsoIdentifier;
await _organizationRepository.ReplaceAsync(organization);
await _ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organization.Id,
Enabled = true,
Data = JsonSerializer.Serialize(new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.MasterPassword,
}, JsonHelpers.CamelCase),
});
// Create user with password initially, so we can login
var userEmail = $"user-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(userEmail);
// Add user to organization
var user = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(user);
await OrganizationTestHelpers.CreateUserAsync(_factory, organization.Id, userEmail,
OrganizationUserType.User, userStatusType: OrganizationUserStatusType.Invited);
// Login as the user
await _loginHelper.LoginAsync(userEmail);
// Remove the master password and keys to simulate newly registered SSO user
user.MasterPassword = null;
user.Key = null;
user.PrivateKey = null;
user.PublicKey = null;
user.SignedPublicKey = null;
await _userRepository.ReplaceAsync(user);
var jsonRequest = CreateV2SetPasswordRequestJson(
userEmail,
organization.Identifier,
"integration-test-hint",
includeAccountKeys: true);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
Assert.Fail($"Expected success but got {response.StatusCode}. Error: {errorContent}");
}
// Verify user in database
var updatedUser = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(updatedUser);
Assert.Equal("integration-test-hint", updatedUser.MasterPasswordHint);
// Verify the master password is hashed and stored
Assert.NotNull(updatedUser.MasterPassword);
var verificationResult = _passwordHasher.VerifyHashedPassword(updatedUser, updatedUser.MasterPassword, _newMasterPasswordHash);
Assert.Equal(PasswordVerificationResult.Success, verificationResult);
// Verify KDF settings
Assert.Equal(KdfType.PBKDF2_SHA256, updatedUser.Kdf);
Assert.Equal(600_000, updatedUser.KdfIterations);
Assert.Null(updatedUser.KdfMemory);
Assert.Null(updatedUser.KdfParallelism);
// Verify timestamps are updated
Assert.Equal(DateTime.UtcNow, updatedUser.RevisionDate, TimeSpan.FromMinutes(1));
Assert.Equal(DateTime.UtcNow, updatedUser.AccountRevisionDate, TimeSpan.FromMinutes(1));
// Verify keys are set
Assert.Equal(_masterKeyWrappedUserKey, updatedUser.Key);
Assert.Equal("publicKey", updatedUser.PublicKey);
Assert.Equal(_mockEncryptedType7String, updatedUser.PrivateKey);
Assert.Equal("signedPublicKey", updatedUser.SignedPublicKey);
// Verify security state
Assert.Equal(2, updatedUser.SecurityVersion);
Assert.Equal("v2", updatedUser.SecurityState);
// Verify signature key pair data
var signatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(updatedUser.Id);
Assert.NotNull(signatureKeyPair);
Assert.Equal(Ed25519, signatureKeyPair.SignatureAlgorithm);
Assert.Equal(_mockEncryptedType7WrappedSigningKey, signatureKeyPair.WrappedSigningKey);
Assert.Equal("verifyingKey", signatureKeyPair.VerifyingKey);
// Verify User_ChangedPassword event was logged
var events = await _eventRepository.GetManyByUserAsync(updatedUser.Id, DateTime.UtcNow.AddMinutes(-5), DateTime.UtcNow.AddMinutes(1), new PageOptions { PageSize = 100 });
Assert.NotNull(events);
Assert.Contains(events.Data, e => e.Type == EventType.User_ChangedPassword && e.UserId == updatedUser.Id);
// Verify user was accepted into the organization
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(updatedUser.Id);
var orgUser = orgUsers.FirstOrDefault(ou => ou.OrganizationId == organization.Id);
Assert.NotNull(orgUser);
Assert.Equal(OrganizationUserStatusType.Accepted, orgUser.Status);
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_TDEDecryption_Success(string organizationSsoIdentifier)
{
// Arrange - Create organization with TDE
var ownerEmail = $"owner-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(ownerEmail);
var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory,
ownerEmail: ownerEmail,
name: "Test Org TDE");
organization.UseSso = true;
organization.Identifier = organizationSsoIdentifier;
await _organizationRepository.ReplaceAsync(organization);
// Configure SSO for TDE (TrustedDeviceEncryption)
await _ssoConfigRepository.CreateAsync(new SsoConfig
{
OrganizationId = organization.Id,
Enabled = true,
Data = JsonSerializer.Serialize(new SsoConfigurationData
{
MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption,
}, JsonHelpers.CamelCase),
});
// Create user with password initially, so we can login
var userEmail = $"user-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(userEmail);
var user = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(user);
// Add user to organization and confirm them (TDE users are confirmed, not invited)
await OrganizationTestHelpers.CreateUserAsync(_factory, organization.Id, userEmail,
OrganizationUserType.User, userStatusType: OrganizationUserStatusType.Confirmed);
// Login as the user
await _loginHelper.LoginAsync(userEmail);
// Set up TDE user with V2 account keys but no master password
// TDE users already have their account keys from device provisioning
user.MasterPassword = null;
user.Key = null;
user.PublicKey = "tde-publicKey";
user.PrivateKey = _mockEncryptedType7String;
user.SignedPublicKey = "tde-signedPublicKey";
user.SecurityVersion = 2;
user.SecurityState = "v2-tde";
await _userRepository.ReplaceAsync(user);
// Create signature key pair for TDE user
var signatureKeyPairData = new Core.KeyManagement.Models.Data.SignatureKeyPairData(
Ed25519,
_mockEncryptedType7WrappedSigningKey,
"tde-verifyingKey");
var setSignatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(user.Id);
if (setSignatureKeyPair == null)
{
var newKeyPair = new Core.KeyManagement.Entities.UserSignatureKeyPair
{
UserId = user.Id,
SignatureAlgorithm = signatureKeyPairData.SignatureAlgorithm,
SigningKey = signatureKeyPairData.WrappedSigningKey,
VerifyingKey = signatureKeyPairData.VerifyingKey,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow
};
newKeyPair.SetNewId();
await _userSignatureKeyPairRepository.CreateAsync(newKeyPair);
}
var jsonRequest = CreateV2SetPasswordRequestJson(
userEmail,
organization.Identifier,
"tde-test-hint",
includeAccountKeys: false);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
if (!response.IsSuccessStatusCode)
{
var errorContent = await response.Content.ReadAsStringAsync();
Assert.Fail($"Expected success but got {response.StatusCode}. Error: {errorContent}");
}
// Verify user in database
var updatedUser = await _userRepository.GetByEmailAsync(userEmail);
Assert.NotNull(updatedUser);
Assert.Equal("tde-test-hint", updatedUser.MasterPasswordHint);
// Verify the master password is hashed and stored
Assert.NotNull(updatedUser.MasterPassword);
var verificationResult = _passwordHasher.VerifyHashedPassword(updatedUser, updatedUser.MasterPassword, _newMasterPasswordHash);
Assert.Equal(PasswordVerificationResult.Success, verificationResult);
// Verify KDF settings
Assert.Equal(KdfType.PBKDF2_SHA256, updatedUser.Kdf);
Assert.Equal(600_000, updatedUser.KdfIterations);
Assert.Null(updatedUser.KdfMemory);
Assert.Null(updatedUser.KdfParallelism);
// Verify timestamps are updated
Assert.Equal(DateTime.UtcNow, updatedUser.RevisionDate, TimeSpan.FromMinutes(1));
Assert.Equal(DateTime.UtcNow, updatedUser.AccountRevisionDate, TimeSpan.FromMinutes(1));
// Verify key is set
Assert.Equal(_masterKeyWrappedUserKey, updatedUser.Key);
// Verify AccountKeys are preserved (TDE users already had V2 keys)
Assert.Equal("tde-publicKey", updatedUser.PublicKey);
Assert.Equal(_mockEncryptedType7String, updatedUser.PrivateKey);
Assert.Equal("tde-signedPublicKey", updatedUser.SignedPublicKey);
Assert.Equal(2, updatedUser.SecurityVersion);
Assert.Equal("v2-tde", updatedUser.SecurityState);
// Verify signature key pair is preserved (TDE users already had signature keys)
var signatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(updatedUser.Id);
Assert.NotNull(signatureKeyPair);
Assert.Equal(Ed25519, signatureKeyPair.SignatureAlgorithm);
Assert.Equal(_mockEncryptedType7WrappedSigningKey, signatureKeyPair.WrappedSigningKey);
Assert.Equal("tde-verifyingKey", signatureKeyPair.VerifyingKey);
// Verify User_ChangedPassword event was logged
var events = await _eventRepository.GetManyByUserAsync(updatedUser.Id, DateTime.UtcNow.AddMinutes(-5), DateTime.UtcNow.AddMinutes(1), new PageOptions { PageSize = 100 });
Assert.NotNull(events);
Assert.Contains(events.Data, e => e.Type == EventType.User_ChangedPassword && e.UserId == updatedUser.Id);
// Verify user remains confirmed in the organization
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(updatedUser.Id);
var orgUser = orgUsers.FirstOrDefault(ou => ou.OrganizationId == organization.Id);
Assert.NotNull(orgUser);
Assert.Equal(OrganizationUserStatusType.Confirmed, orgUser.Status);
}
[Fact]
public async Task PostSetPasswordAsync_V2_Unauthorized_ReturnsUnauthorized()
{
// Arrange - Don't login
var jsonRequest = CreateV2SetPasswordRequestJson(
"test@bitwarden.com",
"test-org-identifier",
"test-hint",
includeAccountKeys: true);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);
}
[Fact]
public async Task PostSetPasswordAsync_V2_MismatchedKdfSettings_ReturnsBadRequest()
{
// Arrange
var email = $"kdf-mismatch-test-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(email);
await _loginHelper.LoginAsync(email);
// Test mismatched KDF settings (600000 vs 650000 iterations)
var request = new
{
masterPasswordAuthentication = new
{
kdf = new
{
kdfType = 0,
iterations = 600000
},
masterPasswordAuthenticationHash = _newMasterPasswordHash,
salt = email
},
masterPasswordUnlock = new
{
kdf = new
{
kdfType = 0,
iterations = 650000 // Different from authentication KDF
},
masterKeyWrappedUserKey = _masterKeyWrappedUserKey,
salt = email
},
accountKeys = new
{
userKeyEncryptedAccountPrivateKey = "7.AOs41Hd8OQiCPXjyJKCiDA==",
accountPublicKey = "public-key"
},
orgIdentifier = "test-org-identifier"
};
var jsonRequest = JsonSerializer.Serialize(request, JsonHelpers.CamelCase);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Theory]
[InlineData(KdfType.PBKDF2_SHA256, 1, null, null)]
[InlineData(KdfType.Argon2id, 4, null, 5)]
[InlineData(KdfType.Argon2id, 4, 65, null)]
public async Task PostSetPasswordAsync_V2_InvalidKdfSettings_ReturnsBadRequest(
KdfType kdf, int kdfIterations, int? kdfMemory, int? kdfParallelism)
{
// Arrange
var email = $"invalid-kdf-test-{Guid.NewGuid()}@bitwarden.com";
await _factory.LoginWithNewAccount(email);
await _loginHelper.LoginAsync(email);
var jsonRequest = CreateV2SetPasswordRequestJson(
email,
"test-org-identifier",
"test-hint",
includeAccountKeys: true,
kdfType: kdf,
kdfIterations: kdfIterations,
kdfMemory: kdfMemory,
kdfParallelism: kdfParallelism);
// Act
using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/set-password");
message.Content = new StringContent(jsonRequest, System.Text.Encoding.UTF8, "application/json");
var response = await _client.SendAsync(message);
// Assert
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
private static string CreateV2SetPasswordRequestJson(
string userEmail,
string orgIdentifier,
string hint,
bool includeAccountKeys = true,
KdfType? kdfType = null,
int? kdfIterations = null,
int? kdfMemory = null,
int? kdfParallelism = null)
{
var kdf = new
{
kdfType = (int)(kdfType ?? KdfType.PBKDF2_SHA256),
iterations = kdfIterations ?? 600000,
memory = kdfMemory,
parallelism = kdfParallelism
};
var request = new
{
masterPasswordAuthentication = new
{
kdf,
masterPasswordAuthenticationHash = _newMasterPasswordHash,
salt = userEmail
},
masterPasswordUnlock = new
{
kdf,
masterKeyWrappedUserKey = _masterKeyWrappedUserKey,
salt = userEmail
},
accountKeys = includeAccountKeys ? new
{
accountPublicKey = "publicKey",
userKeyEncryptedAccountPrivateKey = _mockEncryptedType7String,
publicKeyEncryptionKeyPair = new
{
publicKey = "publicKey",
wrappedPrivateKey = _mockEncryptedType7String,
signedPublicKey = "signedPublicKey"
},
signatureKeyPair = new
{
signatureAlgorithm = "ed25519",
wrappedSigningKey = _mockEncryptedType7WrappedSigningKey,
verifyingKey = "verifyingKey"
},
securityState = new
{
securityVersion = 2,
securityState = "v2"
}
} : null,
masterPasswordHint = hint,
orgIdentifier
};
return JsonSerializer.Serialize(request, JsonHelpers.CamelCase);
}
}

View File

@@ -49,6 +49,7 @@ public class ProfileOrganizationResponseModelTests
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UsePhishingBlocker = organization.UsePhishingBlocker,
UseDisableSMAdsForUsers = organization.UseDisableSmAdsForUsers,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,

View File

@@ -46,6 +46,7 @@ public class ProfileProviderOrganizationResponseModelTests
UseCustomPermissions = organization.UseCustomPermissions,
UseRiskInsights = organization.UseRiskInsights,
UsePhishingBlocker = organization.UsePhishingBlocker,
UseDisableSMAdsForUsers = organization.UseDisableSmAdsForUsers,
UseOrganizationDomains = organization.UseOrganizationDomains,
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies,
UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation,

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,14 +1,17 @@
using System.Security.Claims;
using Bit.Api.Auth.Controllers;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.KeyManagement.Models.Requests;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Services;
using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces;
using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces;
using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Kdf;
using Bit.Core.KeyManagement.Models.Api.Request;
@@ -33,7 +36,9 @@ public class AccountsControllerTests : IDisposable
private readonly IProviderUserRepository _providerUserRepository;
private readonly IPolicyService _policyService;
private readonly ISetInitialMasterPasswordCommand _setInitialMasterPasswordCommand;
private readonly ISetInitialMasterPasswordCommandV1 _setInitialMasterPasswordCommandV1;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly ITdeSetPasswordCommand _tdeSetPasswordCommand;
private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand;
private readonly IFeatureService _featureService;
private readonly IUserAccountKeysQuery _userAccountKeysQuery;
@@ -49,7 +54,9 @@ public class AccountsControllerTests : IDisposable
_providerUserRepository = Substitute.For<IProviderUserRepository>();
_policyService = Substitute.For<IPolicyService>();
_setInitialMasterPasswordCommand = Substitute.For<ISetInitialMasterPasswordCommand>();
_setInitialMasterPasswordCommandV1 = Substitute.For<ISetInitialMasterPasswordCommandV1>();
_twoFactorIsEnabledQuery = Substitute.For<ITwoFactorIsEnabledQuery>();
_tdeSetPasswordCommand = Substitute.For<ITdeSetPasswordCommand>();
_tdeOffboardingPasswordCommand = Substitute.For<ITdeOffboardingPasswordCommand>();
_featureService = Substitute.For<IFeatureService>();
_userAccountKeysQuery = Substitute.For<IUserAccountKeysQuery>();
@@ -64,6 +71,8 @@ public class AccountsControllerTests : IDisposable
_userService,
_policyService,
_setInitialMasterPasswordCommand,
_setInitialMasterPasswordCommandV1,
_tdeSetPasswordCommand,
_tdeOffboardingPasswordCommand,
_twoFactorIsEnabledQuery,
_featureService,
@@ -379,13 +388,13 @@ public class AccountsControllerTests : IDisposable
[BitAutoData(true, null, "newPublicKey", false)]
// reject overwriting existing keys
[BitAutoData(true, "newPrivateKey", "newPublicKey", false)]
public async Task PostSetPasswordAsync_WhenUserExistsAndSettingPasswordSucceeds_ShouldHandleKeysCorrectlyAndReturn(
public async Task PostSetPasswordAsync_V1_WhenUserExistsAndSettingPasswordSucceeds_ShouldHandleKeysCorrectlyAndReturn(
bool hasExistingKeys,
string requestPrivateKey,
string requestPublicKey,
bool shouldSucceed,
User user,
SetPasswordRequestModel setPasswordRequestModel)
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
const string existingPublicKey = "existingPublicKey";
@@ -402,13 +411,15 @@ public class AccountsControllerTests : IDisposable
user.PrivateKey = null;
}
UpdateSetInitialPasswordRequestModelToV1(setInitialPasswordRequestModel);
if (requestPrivateKey == null && requestPublicKey == null)
{
setPasswordRequestModel.Keys = null;
setInitialPasswordRequestModel.Keys = null;
}
else
{
setPasswordRequestModel.Keys = new KeysRequestModel
setInitialPasswordRequestModel.Keys = new KeysRequestModel
{
EncryptedPrivateKey = requestPrivateKey,
PublicKey = requestPublicKey
@@ -416,44 +427,44 @@ public class AccountsControllerTests : IDisposable
}
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_setInitialMasterPasswordCommand.SetInitialMasterPasswordAsync(
_setInitialMasterPasswordCommandV1.SetInitialMasterPasswordAsync(
user,
setPasswordRequestModel.MasterPasswordHash,
setPasswordRequestModel.Key,
setPasswordRequestModel.OrgIdentifier)
setInitialPasswordRequestModel.MasterPasswordHash,
setInitialPasswordRequestModel.Key,
setInitialPasswordRequestModel.OrgIdentifier)
.Returns(Task.FromResult(IdentityResult.Success));
// Act
if (shouldSucceed)
{
await _sut.PostSetPasswordAsync(setPasswordRequestModel);
await _sut.PostSetPasswordAsync(setInitialPasswordRequestModel);
// Assert
await _setInitialMasterPasswordCommand.Received(1)
await _setInitialMasterPasswordCommandV1.Received(1)
.SetInitialMasterPasswordAsync(
Arg.Is<User>(u => u == user),
Arg.Is<string>(s => s == setPasswordRequestModel.MasterPasswordHash),
Arg.Is<string>(s => s == setPasswordRequestModel.Key),
Arg.Is<string>(s => s == setPasswordRequestModel.OrgIdentifier));
Arg.Is<string>(s => s == setInitialPasswordRequestModel.MasterPasswordHash),
Arg.Is<string>(s => s == setInitialPasswordRequestModel.Key),
Arg.Is<string>(s => s == setInitialPasswordRequestModel.OrgIdentifier));
// Additional Assertions for User object modifications
Assert.Equal(setPasswordRequestModel.MasterPasswordHint, user.MasterPasswordHint);
Assert.Equal(setPasswordRequestModel.Kdf, user.Kdf);
Assert.Equal(setPasswordRequestModel.KdfIterations, user.KdfIterations);
Assert.Equal(setPasswordRequestModel.KdfMemory, user.KdfMemory);
Assert.Equal(setPasswordRequestModel.KdfParallelism, user.KdfParallelism);
Assert.Equal(setPasswordRequestModel.Key, user.Key);
Assert.Equal(setInitialPasswordRequestModel.MasterPasswordHint, user.MasterPasswordHint);
Assert.Equal(setInitialPasswordRequestModel.Kdf, user.Kdf);
Assert.Equal(setInitialPasswordRequestModel.KdfIterations, user.KdfIterations);
Assert.Equal(setInitialPasswordRequestModel.KdfMemory, user.KdfMemory);
Assert.Equal(setInitialPasswordRequestModel.KdfParallelism, user.KdfParallelism);
Assert.Equal(setInitialPasswordRequestModel.Key, user.Key);
}
else
{
await Assert.ThrowsAsync<BadRequestException>(() => _sut.PostSetPasswordAsync(setPasswordRequestModel));
await Assert.ThrowsAsync<BadRequestException>(() => _sut.PostSetPasswordAsync(setInitialPasswordRequestModel));
}
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_WhenUserExistsAndHasKeysAndKeysAreUpdated_ShouldThrowAsync(
public async Task PostSetPasswordAsync_V1_WhenUserExistsAndHasKeysAndKeysAreUpdated_ShouldThrowAsync(
User user,
SetPasswordRequestModel setPasswordRequestModel)
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
const string existingPublicKey = "existingPublicKey";
@@ -465,47 +476,52 @@ public class AccountsControllerTests : IDisposable
user.PublicKey = existingPublicKey;
user.PrivateKey = existingEncryptedPrivateKey;
setPasswordRequestModel.Keys = new KeysRequestModel()
UpdateSetInitialPasswordRequestModelToV1(setInitialPasswordRequestModel);
setInitialPasswordRequestModel.Keys = new KeysRequestModel()
{
PublicKey = newPublicKey,
EncryptedPrivateKey = newEncryptedPrivateKey
};
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_setInitialMasterPasswordCommand.SetInitialMasterPasswordAsync(
_setInitialMasterPasswordCommandV1.SetInitialMasterPasswordAsync(
user,
setPasswordRequestModel.MasterPasswordHash,
setPasswordRequestModel.Key,
setPasswordRequestModel.OrgIdentifier)
setInitialPasswordRequestModel.MasterPasswordHash,
setInitialPasswordRequestModel.Key,
setInitialPasswordRequestModel.OrgIdentifier)
.Returns(Task.FromResult(IdentityResult.Success));
// Act & Assert
await Assert.ThrowsAsync<BadRequestException>(() => _sut.PostSetPasswordAsync(setPasswordRequestModel));
await Assert.ThrowsAsync<BadRequestException>(() => _sut.PostSetPasswordAsync(setInitialPasswordRequestModel));
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException(
SetPasswordRequestModel setPasswordRequestModel)
public async Task PostSetPasswordAsync_V1_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException(
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
UpdateSetInitialPasswordRequestModelToV1(setInitialPasswordRequestModel);
// Arrange
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult((User)null));
// Act & Assert
await Assert.ThrowsAsync<UnauthorizedAccessException>(() => _sut.PostSetPasswordAsync(setPasswordRequestModel));
await Assert.ThrowsAsync<UnauthorizedAccessException>(() => _sut.PostSetPasswordAsync(setInitialPasswordRequestModel));
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_WhenSettingPasswordFails_ShouldThrowBadRequestException(
public async Task PostSetPasswordAsync_V1_WhenSettingPasswordFails_ShouldThrowBadRequestException(
User user,
SetPasswordRequestModel model)
SetInitialPasswordRequestModel model)
{
UpdateSetInitialPasswordRequestModelToV1(model);
model.Keys = null;
// Arrange
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_setInitialMasterPasswordCommand.SetInitialMasterPasswordAsync(Arg.Any<User>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
_setInitialMasterPasswordCommandV1.SetInitialMasterPasswordAsync(Arg.Any<User>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>())
.Returns(Task.FromResult(IdentityResult.Failed(new IdentityError { Description = "Some Error" })));
// Act & Assert
@@ -845,5 +861,139 @@ public class AccountsControllerTests : IDisposable
Assert.NotNull(result);
Assert.Equal("keys", result.Object);
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_WhenUserExistsAndSettingPasswordSucceeds_ShouldSetInitialMasterPassword(
User user,
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
UpdateSetInitialPasswordRequestModelToV2(setInitialPasswordRequestModel);
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_setInitialMasterPasswordCommand.SetInitialMasterPasswordAsync(user, Arg.Any<SetInitialMasterPasswordDataModel>())
.Returns(Task.CompletedTask);
// Act
await _sut.PostSetPasswordAsync(setInitialPasswordRequestModel);
// Assert
await _setInitialMasterPasswordCommand.Received(1)
.SetInitialMasterPasswordAsync(
Arg.Is<User>(u => u == user),
Arg.Is<SetInitialMasterPasswordDataModel>(d =>
d.MasterPasswordAuthentication != null &&
d.MasterPasswordUnlock != null &&
d.AccountKeys != null &&
d.OrgSsoIdentifier == setInitialPasswordRequestModel.OrgIdentifier));
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_WithTdeSetPassword_ShouldCallTdeSetPasswordCommand(
User user,
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
UpdateSetInitialPasswordRequestModelToV2(setInitialPasswordRequestModel, includeTdeSetPassword: true);
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_tdeSetPasswordCommand.SetMasterPasswordAsync(user, Arg.Any<SetInitialMasterPasswordDataModel>())
.Returns(Task.CompletedTask);
// Act
await _sut.PostSetPasswordAsync(setInitialPasswordRequestModel);
// Assert
await _tdeSetPasswordCommand.Received(1)
.SetMasterPasswordAsync(
Arg.Is<User>(u => u == user),
Arg.Is<SetInitialMasterPasswordDataModel>(d =>
d.MasterPasswordAuthentication != null &&
d.MasterPasswordUnlock != null &&
d.AccountKeys == null &&
d.OrgSsoIdentifier == setInitialPasswordRequestModel.OrgIdentifier));
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException(
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
UpdateSetInitialPasswordRequestModelToV2(setInitialPasswordRequestModel);
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult((User)null));
// Act & Assert
await Assert.ThrowsAsync<UnauthorizedAccessException>(() => _sut.PostSetPasswordAsync(setInitialPasswordRequestModel));
}
[Theory]
[BitAutoData]
public async Task PostSetPasswordAsync_V2_WhenSettingPasswordFails_ShouldThrowException(
User user,
SetInitialPasswordRequestModel setInitialPasswordRequestModel)
{
// Arrange
UpdateSetInitialPasswordRequestModelToV2(setInitialPasswordRequestModel);
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(Task.FromResult(user));
_setInitialMasterPasswordCommand.SetInitialMasterPasswordAsync(user, Arg.Any<SetInitialMasterPasswordDataModel>())
.Returns(Task.FromException(new Exception("Setting password failed")));
// Act & Assert
await Assert.ThrowsAsync<Exception>(() => _sut.PostSetPasswordAsync(setInitialPasswordRequestModel));
}
private void UpdateSetInitialPasswordRequestModelToV1(SetInitialPasswordRequestModel model)
{
model.MasterPasswordAuthentication = null;
model.MasterPasswordUnlock = null;
model.AccountKeys = null;
}
private void UpdateSetInitialPasswordRequestModelToV2(SetInitialPasswordRequestModel model, bool includeTdeSetPassword = false)
{
var kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
};
model.MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = kdf,
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
};
model.MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = kdf,
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
};
if (includeTdeSetPassword)
{
// TDE set password does not include AccountKeys
model.AccountKeys = null;
}
else
{
model.AccountKeys = new AccountKeysRequestModel
{
UserKeyEncryptedAccountPrivateKey = "privateKey",
AccountPublicKey = "publicKey"
};
}
// Clear V1 properties
model.MasterPasswordHash = null;
model.Key = null;
model.Keys = null;
model.Kdf = null;
model.KdfIterations = null;
model.KdfMemory = null;
model.KdfParallelism = null;
}
}

View File

@@ -0,0 +1,682 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.KeyManagement.Models.Requests;
using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Models.Api.Request;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Api.Test.Auth.Models.Request.Accounts;
public class SetInitialPasswordRequestModelTests
{
#region V2 Validation Tests
[Theory]
[InlineData(KdfType.PBKDF2_SHA256, 600000, null, null)]
[InlineData(KdfType.Argon2id, 3, 64, 4)]
public void Validate_V2Request_WithMatchingKdf_ReturnsNoErrors(KdfType kdfType, int iterations, int? memory, int? parallelism)
{
// Arrange
var kdf = new KdfRequestModel
{
KdfType = kdfType,
Iterations = iterations,
Memory = memory,
Parallelism = parallelism
};
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = "orgIdentifier",
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = kdf,
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = kdf,
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
},
AccountKeys = new AccountKeysRequestModel
{
UserKeyEncryptedAccountPrivateKey = "privateKey",
AccountPublicKey = "publicKey"
}
};
// Act
var result = model.Validate(new ValidationContext(model));
// Assert
Assert.Empty(result);
}
[Theory]
[BitAutoData]
public void Validate_V2Request_WithMismatchedKdfSettings_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 650000 // Different iterations
},
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
}
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Single(result);
Assert.Contains("KDF settings must be equal", result[0].ErrorMessage);
var memberNames = result[0].MemberNames.ToList();
Assert.Equal(2, memberNames.Count);
Assert.Contains("MasterPasswordAuthentication.Kdf", memberNames);
Assert.Contains("MasterPasswordUnlock.Kdf", memberNames);
}
[Theory]
[BitAutoData]
public void Validate_V2Request_WithInvalidAuthenticationKdf_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 1 // Too low
};
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = kdf,
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = kdf,
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
}
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.NotEmpty(result);
Assert.Contains(result, r => r.ErrorMessage != null && r.ErrorMessage.Contains("KDF iterations must be between"));
}
#endregion
#region V1 Validation Tests (Obsolete)
[Theory]
[BitAutoData]
public void Validate_V1Request_WithMissingMasterPasswordHash_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
Key = "key",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 600000
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage.Contains("MasterPasswordHash must be supplied"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithMissingKey_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 600000
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage.Contains("Key must be supplied"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithMissingKdf_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
KdfIterations = 600000
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage != null && r.ErrorMessage.Contains("Kdf must be supplied"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithMissingKdfIterations_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
Kdf = KdfType.PBKDF2_SHA256
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage != null && r.ErrorMessage.Contains("KdfIterations must be supplied"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithArgon2idAndMissingMemory_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
Kdf = KdfType.Argon2id,
KdfIterations = 3,
KdfParallelism = 4
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage.Contains("KdfMemory must be supplied when Kdf is Argon2id"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithArgon2idAndMissingParallelism_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
Kdf = KdfType.Argon2id,
KdfIterations = 3,
KdfMemory = 64
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.Contains(result, r => r.ErrorMessage.Contains("KdfParallelism must be supplied when Kdf is Argon2id"));
}
[Theory]
[BitAutoData]
public void Validate_V1Request_WithInvalidKdfSettings_ReturnsValidationError(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 5000 // Too low
};
// Act
var result = model.Validate(new ValidationContext(model)).ToList();
// Assert
Assert.NotEmpty(result);
Assert.Contains(result, r => r.ErrorMessage != null && r.ErrorMessage.Contains("KDF iterations must be between"));
}
[Theory]
[InlineData(KdfType.PBKDF2_SHA256, 600000, null, null)]
[InlineData(KdfType.Argon2id, 3, 64, 4)]
public void Validate_V1Request_WithValidSettings_ReturnsNoErrors(KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = "orgIdentifier",
MasterPasswordHash = "hash",
Key = "key",
Kdf = kdfType,
KdfIterations = kdfIterations,
KdfMemory = kdfMemory,
KdfParallelism = kdfParallelism
};
// Act
var result = model.Validate(new ValidationContext(model));
// Assert
Assert.Empty(result);
}
#endregion
#region IsV2Request Tests
[Theory]
[BitAutoData]
public void IsV2Request_WithV2Properties_ReturnsTrue(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
}
};
// Act
var result = model.IsV2Request();
// Assert
Assert.True(result);
}
[Theory]
[BitAutoData]
public void IsV2Request_WithoutMasterPasswordAuthentication_ReturnsFalse(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
}
};
// Act
var result = model.IsV2Request();
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public void IsV2Request_WithoutMasterPasswordUnlock_ReturnsFalse(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
}
};
// Act
var result = model.IsV2Request();
// Assert
Assert.False(result);
}
[Theory]
[BitAutoData]
public void IsV2Request_WithV1Properties_ReturnsFalse(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHash = "hash",
Key = "key",
Kdf = KdfType.PBKDF2_SHA256,
KdfIterations = 600000
};
// Act
var result = model.IsV2Request();
// Assert
Assert.False(result);
}
#endregion
#region IsTdeSetPasswordRequest Tests
[Theory]
[BitAutoData]
public void IsTdeSetPasswordRequest_WithNullAccountKeys_ReturnsTrue(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
},
AccountKeys = null
};
// Act
var result = model.IsTdeSetPasswordRequest();
// Assert
Assert.True(result);
}
[Theory]
[BitAutoData]
public void IsTdeSetPasswordRequest_WithAccountKeys_ReturnsFalse(string orgIdentifier)
{
// Arrange
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
},
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
},
AccountKeys = new AccountKeysRequestModel
{
UserKeyEncryptedAccountPrivateKey = "privateKey",
AccountPublicKey = "publicKey"
}
};
// Act
var result = model.IsTdeSetPasswordRequest();
// Assert
Assert.False(result);
}
#endregion
#region ToUser Tests (Obsolete)
[Theory]
[InlineData(KdfType.PBKDF2_SHA256, 600000, null, null)]
[InlineData(KdfType.Argon2id, 3, 64, 4)]
public void ToUser_WithKeys_MapsPropertiesCorrectly(KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism)
{
// Arrange
var existingUser = new User();
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = "orgIdentifier",
MasterPasswordHash = "hash",
MasterPasswordHint = "hint",
Key = "key",
Kdf = kdfType,
KdfIterations = kdfIterations,
KdfMemory = kdfMemory,
KdfParallelism = kdfParallelism,
Keys = new KeysRequestModel
{
PublicKey = "publicKey",
EncryptedPrivateKey = "encryptedPrivateKey"
}
};
// Act
var result = model.ToUser(existingUser);
// Assert
Assert.Same(existingUser, result);
Assert.Equal("hint", result.MasterPasswordHint);
Assert.Equal(kdfType, result.Kdf);
Assert.Equal(kdfIterations, result.KdfIterations);
Assert.Equal(kdfMemory, result.KdfMemory);
Assert.Equal(kdfParallelism, result.KdfParallelism);
Assert.Equal("key", result.Key);
Assert.Equal("publicKey", result.PublicKey);
Assert.Equal("encryptedPrivateKey", result.PrivateKey);
}
[Theory]
[InlineData(KdfType.PBKDF2_SHA256, 600000, null, null)]
[InlineData(KdfType.Argon2id, 3, 64, 4)]
public void ToUser_WithoutKeys_MapsPropertiesCorrectly(KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism)
{
// Arrange
var existingUser = new User();
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = "orgIdentifier",
MasterPasswordHash = "hash",
MasterPasswordHint = "hint",
Key = "key",
Kdf = kdfType,
KdfIterations = kdfIterations,
KdfMemory = kdfMemory,
KdfParallelism = kdfParallelism,
Keys = null
};
// Act
var result = model.ToUser(existingUser);
// Assert
Assert.Same(existingUser, result);
Assert.Equal("hint", result.MasterPasswordHint);
Assert.Equal(kdfType, result.Kdf);
Assert.Equal(kdfIterations, result.KdfIterations);
Assert.Equal(kdfMemory, result.KdfMemory);
Assert.Equal(kdfParallelism, result.KdfParallelism);
Assert.Equal("key", result.Key);
Assert.Null(result.PublicKey);
Assert.Null(result.PrivateKey);
}
#endregion
#region ToData Tests
[Theory]
[BitAutoData]
public void ToData_MapsPropertiesCorrectly(string orgIdentifier)
{
// Arrange
var kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
};
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHint = "hint",
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = kdf,
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = kdf,
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
},
AccountKeys = new AccountKeysRequestModel
{
UserKeyEncryptedAccountPrivateKey = "privateKey",
AccountPublicKey = "publicKey"
}
};
// Act
var result = model.ToData();
// Assert
Assert.NotNull(result);
Assert.Equal(orgIdentifier, result.OrgSsoIdentifier);
Assert.Equal("hint", result.MasterPasswordHint);
Assert.NotNull(result.MasterPasswordAuthentication);
Assert.NotNull(result.MasterPasswordUnlock);
Assert.NotNull(result.AccountKeys);
Assert.Equal("authHash", result.MasterPasswordAuthentication.MasterPasswordAuthenticationHash);
Assert.Equal("wrappedKey", result.MasterPasswordUnlock.MasterKeyWrappedUserKey);
}
[Theory]
[BitAutoData]
public void ToData_WithNullAccountKeys_MapsCorrectly(string orgIdentifier)
{
// Arrange
var kdf = new KdfRequestModel
{
KdfType = KdfType.PBKDF2_SHA256,
Iterations = 600000
};
var model = new SetInitialPasswordRequestModel
{
OrgIdentifier = orgIdentifier,
MasterPasswordHint = "hint",
MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel
{
Kdf = kdf,
MasterPasswordAuthenticationHash = "authHash",
Salt = "salt"
},
MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel
{
Kdf = kdf,
MasterKeyWrappedUserKey = "wrappedKey",
Salt = "salt"
},
AccountKeys = null
};
// Act
var result = model.ToData();
// Assert
Assert.NotNull(result);
Assert.Equal(orgIdentifier, result.OrgSsoIdentifier);
Assert.Null(result.AccountKeys);
}
#endregion
}

View File

@@ -3,6 +3,8 @@ using Bit.Api.Billing.Models.Requests.Storage;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Licenses.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Billing.Subscriptions.Queries;
using Bit.Core.Entities;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
@@ -17,21 +19,26 @@ public class AccountBillingVNextControllerTests
{
private readonly IUpdatePremiumStorageCommand _updatePremiumStorageCommand;
private readonly IGetUserLicenseQuery _getUserLicenseQuery;
private readonly IUpgradePremiumToOrganizationCommand _upgradePremiumToOrganizationCommand;
private readonly AccountBillingVNextController _sut;
public AccountBillingVNextControllerTests()
{
_updatePremiumStorageCommand = Substitute.For<IUpdatePremiumStorageCommand>();
_getUserLicenseQuery = Substitute.For<IGetUserLicenseQuery>();
_upgradePremiumToOrganizationCommand = Substitute.For<IUpgradePremiumToOrganizationCommand>();
_sut = new AccountBillingVNextController(
Substitute.For<Core.Billing.Payment.Commands.ICreateBitPayInvoiceForCreditCommand>(),
Substitute.For<Core.Billing.Premium.Commands.ICreatePremiumCloudHostedSubscriptionCommand>(),
Substitute.For<IGetBitwardenSubscriptionQuery>(),
Substitute.For<Core.Billing.Payment.Queries.IGetCreditQuery>(),
Substitute.For<Core.Billing.Payment.Queries.IGetPaymentMethodQuery>(),
_getUserLicenseQuery,
Substitute.For<IReinstateSubscriptionCommand>(),
Substitute.For<Core.Billing.Payment.Commands.IUpdatePaymentMethodCommand>(),
_updatePremiumStorageCommand);
_updatePremiumStorageCommand,
_upgradePremiumToOrganizationCommand);
}
[Theory, BitAutoData]
@@ -60,7 +67,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -80,7 +87,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -100,7 +107,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -120,7 +127,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -140,7 +147,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -160,7 +167,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -179,7 +186,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -198,7 +205,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -217,7 +224,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -236,7 +243,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);

View File

@@ -8,8 +8,8 @@ using Bit.Api.Tools.Models.Request;
using Bit.Api.Tools.Models.Response;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
@@ -28,7 +28,6 @@ namespace Bit.Api.Test.Tools.Controllers;
public class SendsControllerTests : IDisposable
{
private readonly SendsController _sut;
private readonly GlobalSettings _globalSettings;
private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
@@ -37,6 +36,8 @@ public class SendsControllerTests : IDisposable
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly ILogger<SendsController> _logger;
private readonly IFeatureService _featureService;
private readonly IPushNotificationService _pushNotificationService;
public SendsControllerTests()
{
@@ -47,8 +48,9 @@ public class SendsControllerTests : IDisposable
_sendOwnerQuery = Substitute.For<ISendOwnerQuery>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
_globalSettings = new GlobalSettings();
_logger = Substitute.For<ILogger<SendsController>>();
_featureService = Substitute.For<IFeatureService>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_sut = new SendsController(
_sendRepository,
@@ -59,7 +61,8 @@ public class SendsControllerTests : IDisposable
_sendOwnerQuery,
_sendFileStorageService,
_logger,
_globalSettings
_featureService,
_pushNotificationService
);
}
@@ -96,8 +99,8 @@ public class SendsControllerTests : IDisposable
{
var now = DateTime.UtcNow;
var expected = "You cannot have a Send with a deletion date that far " +
"into the future. Adjust the Deletion Date to a value less than 31 days from now " +
"and try again.";
"into the future. Adjust the Deletion Date to a value less than 31 days from now " +
"and try again.";
var request = new SendRequestModel() { DeletionDate = now.AddDays(32) };
var exception = await Assert.ThrowsAsync<BadRequestException>(() => _sut.Post(request));
@@ -109,9 +112,10 @@ public class SendsControllerTests : IDisposable
{
var now = DateTime.UtcNow;
var expected = "You cannot have a Send with a deletion date that far " +
"into the future. Adjust the Deletion Date to a value less than 31 days from now " +
"and try again.";
var request = new SendRequestModel() { Type = SendType.File, FileLength = 1024L, DeletionDate = now.AddDays(32) };
"into the future. Adjust the Deletion Date to a value less than 31 days from now " +
"and try again.";
var request =
new SendRequestModel() { Type = SendType.File, FileLength = 1024L, DeletionDate = now.AddDays(32) };
var exception = await Assert.ThrowsAsync<BadRequestException>(() => _sut.PostFile(request));
Assert.Equal(expected, exception.Message);
@@ -409,7 +413,8 @@ public class SendsControllerTests : IDisposable
}
[Theory, AutoData]
public async Task PutRemovePassword_WithWrongUser_ThrowsNotFoundException(Guid userId, Guid otherUserId, Guid sendId)
public async Task PutRemovePassword_WithWrongUser_ThrowsNotFoundException(Guid userId, Guid otherUserId,
Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
@@ -753,4 +758,683 @@ public class SendsControllerTests : IDisposable
s.Password == null &&
s.Emails == null));
}
#region Authenticated Access Endpoints
[Theory, AutoData]
public async Task AccessUsingAuth_WithValidSend_ReturnsSendAccessResponse(Guid sendId, User creator)
{
var send = new Send
{
Id = sendId,
UserId = creator.Id,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
HideEmail = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
_userService.GetUserByIdAsync(creator.Id).Returns(creator);
var result = await _sut.AccessUsingAuth();
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendAccessResponseModel>(objectResult.Value);
Assert.Equal(CoreHelpers.Base64UrlEncode(sendId.ToByteArray()), response.Id);
Assert.Equal(creator.Email, response.CreatorIdentifier);
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _userService.Received(1).GetUserByIdAsync(creator.Id);
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithHideEmail_DoesNotIncludeCreatorIdentifier(Guid sendId, User creator)
{
var send = new Send
{
Id = sendId,
UserId = creator.Id,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
HideEmail = true,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
var result = await _sut.AccessUsingAuth();
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendAccessResponseModel>(objectResult.Value);
Assert.Equal(CoreHelpers.Base64UrlEncode(sendId.ToByteArray()), response.Id);
Assert.Null(response.CreatorIdentifier);
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any<Guid>());
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithNoUserId_DoesNotIncludeCreatorIdentifier(Guid sendId)
{
var send = new Send
{
Id = sendId,
UserId = null,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
HideEmail = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
var result = await _sut.AccessUsingAuth();
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendAccessResponseModel>(objectResult.Value);
Assert.Equal(CoreHelpers.Base64UrlEncode(sendId.ToByteArray()), response.Id);
Assert.Null(response.CreatorIdentifier);
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any<Guid>());
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithNonExistentSend_ThrowsBadRequestException(Guid sendId)
{
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns((Send)null);
var exception =
await Assert.ThrowsAsync<BadRequestException>(() => _sut.AccessUsingAuth());
Assert.Equal("Could not locate send", exception.Message);
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithFileSend_ReturnsCorrectResponse(Guid sendId, User creator)
{
var fileData = new SendFileData("Test File", "Notes", "document.pdf") { Id = "file-123", Size = 2048 };
var send = new Send
{
Id = sendId,
UserId = creator.Id,
Type = SendType.File,
Data = JsonSerializer.Serialize(fileData),
HideEmail = false,
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
_userService.GetUserByIdAsync(creator.Id).Returns(creator);
var result = await _sut.AccessUsingAuth();
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendAccessResponseModel>(objectResult.Value);
Assert.Equal(CoreHelpers.Base64UrlEncode(sendId.ToByteArray()), response.Id);
Assert.Equal(SendType.File, response.Type);
Assert.NotNull(response.File);
Assert.Equal("file-123", response.File.Id);
Assert.Equal(creator.Email, response.CreatorIdentifier);
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithValidFileId_ReturnsDownloadUrl(
Guid sendId, string fileId, string expectedUrl)
{
var fileData = new SendFileData("Test File", "Notes", "document.pdf") { Id = fileId, Size = 2048 };
var send = new Send
{
Id = sendId,
Type = SendType.File,
Data = JsonSerializer.Serialize(fileData),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
_sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl);
var result = await _sut.GetSendFileDownloadDataUsingAuth(fileId);
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendFileDownloadDataResponseModel>(objectResult.Value);
Assert.Equal(fileId, response.Id);
Assert.Equal(expectedUrl, response.Url);
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _sendFileStorageService.Received(1).GetSendFileDownloadUrlAsync(send, fileId);
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithNonExistentSend_ThrowsBadRequestException(
Guid sendId, string fileId)
{
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns((Send)null);
var exception =
await Assert.ThrowsAsync<BadRequestException>(() => _sut.GetSendFileDownloadDataUsingAuth(fileId));
Assert.Equal("Could not locate send", exception.Message);
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _sendFileStorageService.DidNotReceive()
.GetSendFileDownloadUrlAsync(Arg.Any<Send>(), Arg.Any<string>());
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithTextSend_StillReturnsResponse(
Guid sendId, string fileId, string expectedUrl)
{
var send = new Send
{
Id = sendId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
_sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl);
var result = await _sut.GetSendFileDownloadDataUsingAuth(fileId);
Assert.NotNull(result);
var objectResult = Assert.IsType<ObjectResult>(result);
var response = Assert.IsType<SendFileDownloadDataResponseModel>(objectResult.Value);
Assert.Equal(fileId, response.Id);
Assert.Equal(expectedUrl, response.Url);
}
#region AccessUsingAuth Validation Tests
[Theory, AutoData]
public async Task AccessUsingAuth_WithExpiredSend_ThrowsNotFoundException(Guid sendId)
{
var send = new Send
{
Id = sendId,
UserId = Guid.NewGuid(),
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired yesterday
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.AccessUsingAuth());
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithDeletedSend_ThrowsNotFoundException(Guid sendId)
{
var send = new Send
{
Id = sendId,
UserId = Guid.NewGuid(),
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
DeletionDate = DateTime.UtcNow.AddDays(-1), // Should have been deleted yesterday
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.AccessUsingAuth());
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithDisabledSend_ThrowsNotFoundException(Guid sendId)
{
var send = new Send
{
Id = sendId,
UserId = Guid.NewGuid(),
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = true, // Disabled
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.AccessUsingAuth());
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task AccessUsingAuth_WithAccessCountExceeded_ThrowsNotFoundException(Guid sendId)
{
var send = new Send
{
Id = sendId,
UserId = Guid.NewGuid(),
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 5,
MaxAccessCount = 5 // Limit reached
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.AccessUsingAuth());
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
#endregion
#region GetSendFileDownloadDataUsingAuth Validation Tests
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithExpiredSend_ThrowsNotFoundException(
Guid sendId, string fileId)
{
var send = new Send
{
Id = sendId,
Type = SendType.File,
Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.GetSendFileDownloadDataUsingAuth(fileId));
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithDeletedSend_ThrowsNotFoundException(
Guid sendId, string fileId)
{
var send = new Send
{
Id = sendId,
Type = SendType.File,
Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")),
DeletionDate = DateTime.UtcNow.AddDays(-1), // Deleted
ExpirationDate = null,
Disabled = false,
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.GetSendFileDownloadDataUsingAuth(fileId));
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithDisabledSend_ThrowsNotFoundException(
Guid sendId, string fileId)
{
var send = new Send
{
Id = sendId,
Type = SendType.File,
Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = true, // Disabled
AccessCount = 0,
MaxAccessCount = null
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.GetSendFileDownloadDataUsingAuth(fileId));
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
[Theory, AutoData]
public async Task GetSendFileDownloadDataUsingAuth_WithAccessCountExceeded_ThrowsNotFoundException(
Guid sendId, string fileId)
{
var send = new Send
{
Id = sendId,
Type = SendType.File,
Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")),
DeletionDate = DateTime.UtcNow.AddDays(7),
ExpirationDate = null,
Disabled = false,
AccessCount = 10,
MaxAccessCount = 10 // Limit reached
};
var user = CreateUserWithSendIdClaim(sendId);
_sut.ControllerContext = CreateControllerContextWithUser(user);
_sendRepository.GetByIdAsync(sendId).Returns(send);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.GetSendFileDownloadDataUsingAuth(fileId));
await _sendRepository.Received(1).GetByIdAsync(sendId);
}
#endregion
#endregion
#region PutRemoveAuth Tests
[Theory, AutoData]
public async Task PutRemoveAuth_WithPasswordProtectedSend_RemovesPasswordAndSetsAuthTypeNone(Guid userId,
Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = "hashed-password",
Emails = null,
AuthType = AuthType.Password
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
Assert.Null(result.Password);
Assert.Null(result.Emails);
await _nonAnonymousSendCommand.Received(1).SaveSendAsync(Arg.Is<Send>(s =>
s.Id == sendId &&
s.Password == null &&
s.Emails == null &&
s.AuthType == AuthType.None));
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithEmailProtectedSend_RemovesEmailsAndSetsAuthTypeNone(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = null,
Emails = "test@example.com,user@example.com",
AuthType = AuthType.Email
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
Assert.Null(result.Password);
Assert.Null(result.Emails);
await _nonAnonymousSendCommand.Received(1).SaveSendAsync(Arg.Is<Send>(s =>
s.Id == sendId &&
s.Password == null &&
s.Emails == null &&
s.AuthType == AuthType.None));
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithSendAlreadyHavingNoAuth_StillSucceeds(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = null,
Emails = null,
AuthType = AuthType.None
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
Assert.Null(result.Password);
Assert.Null(result.Emails);
await _nonAnonymousSendCommand.Received(1).SaveSendAsync(Arg.Is<Send>(s =>
s.Id == sendId &&
s.Password == null &&
s.Emails == null &&
s.AuthType == AuthType.None));
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithFileSend_RemovesAuthAndPreservesFileData(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var fileData = new SendFileData("Test File", "Notes", "document.pdf") { Id = "file-123", Size = 2048 };
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.File,
Data = JsonSerializer.Serialize(fileData),
Password = "hashed-password",
Emails = null,
AuthType = AuthType.Password
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
Assert.Equal(SendType.File, result.Type);
Assert.NotNull(result.File);
Assert.Equal("file-123", result.File.Id);
Assert.Null(result.Password);
Assert.Null(result.Emails);
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithNonExistentSend_ThrowsNotFoundException(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
_sendRepository.GetByIdAsync(sendId).Returns((Send)null);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.PutRemoveAuth(sendId.ToString()));
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _nonAnonymousSendCommand.DidNotReceive().SaveSendAsync(Arg.Any<Send>());
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithWrongUser_ThrowsNotFoundException(Guid userId, Guid otherUserId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
{
Id = sendId,
UserId = otherUserId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = "hashed-password",
AuthType = AuthType.Password
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
await Assert.ThrowsAsync<NotFoundException>(() => _sut.PutRemoveAuth(sendId.ToString()));
await _sendRepository.Received(1).GetByIdAsync(sendId);
await _nonAnonymousSendCommand.DidNotReceive().SaveSendAsync(Arg.Any<Send>());
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithNullUserId_ThrowsInvalidOperationException(Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns((Guid?)null);
var exception =
await Assert.ThrowsAsync<InvalidOperationException>(() => _sut.PutRemoveAuth(sendId.ToString()));
Assert.Equal("User ID not found", exception.Message);
await _sendRepository.DidNotReceive().GetByIdAsync(Arg.Any<Guid>());
await _nonAnonymousSendCommand.DidNotReceive().SaveSendAsync(Arg.Any<Send>());
}
[Theory, AutoData]
public async Task PutRemoveAuth_WithSendHavingBothPasswordAndEmails_RemovesBoth(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = "hashed-password",
Emails = "test@example.com",
AuthType = AuthType.Password
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
Assert.Null(result.Password);
Assert.Null(result.Emails);
await _nonAnonymousSendCommand.Received(1).SaveSendAsync(Arg.Is<Send>(s =>
s.Id == sendId &&
s.Password == null &&
s.Emails == null &&
s.AuthType == AuthType.None));
}
[Theory, AutoData]
public async Task PutRemoveAuth_PreservesOtherSendProperties(Guid userId, Guid sendId)
{
_userService.GetProperUserId(Arg.Any<ClaimsPrincipal>()).Returns(userId);
var deletionDate = DateTime.UtcNow.AddDays(7);
var expirationDate = DateTime.UtcNow.AddDays(3);
var existingSend = new Send
{
Id = sendId,
UserId = userId,
Type = SendType.Text,
Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)),
Password = "hashed-password",
AuthType = AuthType.Password,
Key = "encryption-key",
MaxAccessCount = 10,
AccessCount = 3,
DeletionDate = deletionDate,
ExpirationDate = expirationDate,
Disabled = false,
HideEmail = true
};
_sendRepository.GetByIdAsync(sendId).Returns(existingSend);
var result = await _sut.PutRemoveAuth(sendId.ToString());
Assert.NotNull(result);
Assert.Equal(sendId, result.Id);
Assert.Equal(AuthType.None, result.AuthType);
// Verify other properties are preserved
Assert.Equal("encryption-key", result.Key);
Assert.Equal(10, result.MaxAccessCount);
Assert.Equal(3, result.AccessCount);
Assert.Equal(deletionDate, result.DeletionDate);
Assert.Equal(expirationDate, result.ExpirationDate);
Assert.False(result.Disabled);
Assert.True(result.HideEmail);
}
#endregion
#region Test Helpers
private static ClaimsPrincipal CreateUserWithSendIdClaim(Guid sendId)
{
var claims = new List<Claim> { new Claim("send_id", sendId.ToString()) };
var identity = new ClaimsIdentity(claims, "TestAuth");
return new ClaimsPrincipal(identity);
}
private static ControllerContext CreateControllerContextWithUser(ClaimsPrincipal user)
{
return new ControllerContext { HttpContext = new Microsoft.AspNetCore.Http.DefaultHttpContext { User = user } };
}
#endregion
}

View File

@@ -1,4 +1,5 @@
using System.Text;
using System.Globalization;
using System.Text;
using Bit.Billing.Controllers;
using Bit.Billing.Test.Utilities;
using Bit.Core.AdminConsole.Entities;
@@ -565,4 +566,53 @@ public class PayPalControllerTests(ITestOutputHelper testOutputHelper)
private static void LoggedWarning(ICacheLogger<PayPalController> logger, string message)
=> Logged(logger, LogLevel.Warning, message);
[Fact]
public async Task PostIpn_Completed_CreatesTransaction_WithSwedishCulture_Ok()
{
// Save current culture
var originalCulture = CultureInfo.CurrentCulture;
var originalUICulture = CultureInfo.CurrentUICulture;
try
{
// Set Swedish culture (uses comma as decimal separator)
var swedishCulture = new CultureInfo("sv-SE");
CultureInfo.CurrentCulture = swedishCulture;
CultureInfo.CurrentUICulture = swedishCulture;
var logger = testOutputHelper.BuildLoggerFor<PayPalController>();
_billingSettings.Value.Returns(new BillingSettings
{
PayPal =
{
WebhookKey = _defaultWebhookKey,
BusinessId = "NHDYKLQ3L4LWL"
}
});
var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment);
_transactionRepository.GetByGatewayIdAsync(
GatewayType.PayPal,
"2PK15573S8089712Y").ReturnsNull();
var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody);
var result = await controller.PostIpn();
HasStatusCode(result, 200);
await _transactionRepository.Received().CreateAsync(Arg.Is<Transaction>(transaction =>
transaction.Amount == 48M &&
transaction.GatewayId == "2PK15573S8089712Y"));
}
finally
{
// Restore original culture
CultureInfo.CurrentCulture = originalCulture;
CultureInfo.CurrentUICulture = originalUICulture;
}
}
}

View File

@@ -2,6 +2,7 @@
using Bit.Billing.Services;
using Bit.Core;
using Bit.Core.Billing.Constants;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
@@ -17,6 +18,9 @@ public class ReconcileAdditionalStorageJobTests
private readonly IStripeFacade _stripeFacade;
private readonly ILogger<ReconcileAdditionalStorageJob> _logger;
private readonly IFeatureService _featureService;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly ReconcileAdditionalStorageJob _sut;
public ReconcileAdditionalStorageJobTests()
@@ -24,7 +28,20 @@ public class ReconcileAdditionalStorageJobTests
_stripeFacade = Substitute.For<IStripeFacade>();
_logger = Substitute.For<ILogger<ReconcileAdditionalStorageJob>>();
_featureService = Substitute.For<IFeatureService>();
_sut = new ReconcileAdditionalStorageJob(_stripeFacade, _logger, _featureService);
_userRepository = Substitute.For<IUserRepository>();
_organizationRepository = Substitute.For<IOrganizationRepository>();
_stripeEventUtilityService = Substitute.For<IStripeEventUtilityService>();
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, null, null));
_sut = new ReconcileAdditionalStorageJob(
_stripeFacade,
_logger,
_featureService,
_userRepository,
_organizationRepository,
_stripeEventUtilityService);
}
#region Feature Flag Tests
@@ -88,6 +105,36 @@ public class ReconcileAdditionalStorageJobTests
await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!);
}
[Fact]
public async Task Execute_DryRunMode_DoesNotUpdateDatabase()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(false); // Dry run ON
// Create a personal subscription that would normally trigger a database update
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10);
subscription.Metadata = new Dictionary<string, string> { ["userId"] = userId.ToString() };
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
// Mock GetIdsFromMetadata to return userId
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
// Act
await _sut.Execute(context);
// Assert - Verify database repositories are never called
await _userRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
await _userRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default!);
await _organizationRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
await _organizationRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default!);
}
[Fact]
public async Task Execute_DryRunModeDisabled_UpdatesSubscriptions()
{
@@ -96,7 +143,11 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); // Dry run OFF
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
@@ -111,6 +162,150 @@ public class ReconcileAdditionalStorageJobTests
Arg.Is<SubscriptionUpdateOptions>(o => o.Items.Count == 1));
}
[Fact]
public async Task Execute_LiveMode_PersonalSubscription_UpdatesUserDatabase()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
// Setup user
var userId = Guid.NewGuid();
var user = new Bit.Core.Entities.User
{
Id = userId,
Email = "test@example.com",
GatewaySubscriptionId = "sub_personal",
MaxStorageGb = 15 // Old value
};
_userRepository.GetByIdAsync(userId).Returns(user);
_userRepository.ReplaceAsync(user).Returns(Task.CompletedTask);
// Create personal subscription with premium seat + 10 GB storage (will be reduced to 6 GB)
var subscription = CreateSubscriptionWithMultipleItems("sub_personal",
[("premium-annually", 1L), ("storage-gb-monthly", 10L)]);
subscription.Metadata = new Dictionary<string, string> { ["userId"] = userId.ToString() };
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Mock GetIdsFromMetadata to return userId
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
// Act
await _sut.Execute(context);
// Assert - Verify Stripe update happened
await _stripeFacade.Received(1).UpdateSubscription(
"sub_personal",
Arg.Is<SubscriptionUpdateOptions>(o => o.Items.Count == 1 && o.Items[0].Quantity == 6));
// Assert - Verify database update with correct MaxStorageGb (5 included + 6 new quantity = 11)
await _userRepository.Received(1).GetByIdAsync(userId);
await _userRepository.Received(1).ReplaceAsync(user);
Assert.Equal((short)11, user.MaxStorageGb);
}
[Fact]
public async Task Execute_LiveMode_OrganizationSubscription_UpdatesOrganizationDatabase()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
// Setup organization
var organizationId = Guid.NewGuid();
var organization = new Bit.Core.AdminConsole.Entities.Organization
{
Id = organizationId,
Name = "Test Organization",
GatewaySubscriptionId = "sub_org",
MaxStorageGb = 13 // Old value
};
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
_organizationRepository.ReplaceAsync(organization).Returns(Task.CompletedTask);
// Create organization subscription with org seat plan + 8 GB storage (will be reduced to 4 GB)
var subscription = CreateSubscriptionWithMultipleItems("sub_org",
[("2023-teams-org-seat-annually", 5L), ("storage-gb-monthly", 8L)]);
subscription.Metadata = new Dictionary<string, string> { ["organizationId"] = organizationId.ToString() };
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Mock GetIdsFromMetadata to return organizationId
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(organizationId, null, null));
// Act
await _sut.Execute(context);
// Assert - Verify Stripe update happened
await _stripeFacade.Received(1).UpdateSubscription(
"sub_org",
Arg.Is<SubscriptionUpdateOptions>(o => o.Items.Count == 1 && o.Items[0].Quantity == 4));
// Assert - Verify database update with correct MaxStorageGb (5 included + 4 new quantity = 9)
await _organizationRepository.Received(1).GetByIdAsync(organizationId);
await _organizationRepository.Received(1).ReplaceAsync(organization);
Assert.Equal((short)9, organization.MaxStorageGb);
}
[Fact]
public async Task Execute_LiveMode_StorageItemDeleted_UpdatesDatabaseWithBaseStorage()
{
// Arrange
var context = CreateJobExecutionContext();
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
// Setup user
var userId = Guid.NewGuid();
var user = new Bit.Core.Entities.User
{
Id = userId,
Email = "test@example.com",
GatewaySubscriptionId = "sub_delete",
MaxStorageGb = 8 // Old value
};
_userRepository.GetByIdAsync(userId).Returns(user);
_userRepository.ReplaceAsync(user).Returns(Task.CompletedTask);
// Create personal subscription with premium seat + 3 GB storage (will be deleted since 3 < 4)
var subscription = CreateSubscriptionWithMultipleItems("sub_delete",
[("premium-annually", 1L), ("storage-gb-monthly", 3L)]);
subscription.Metadata = new Dictionary<string, string> { ["userId"] = userId.ToString() };
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
_stripeFacade.UpdateSubscription(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(subscription);
// Mock GetIdsFromMetadata to return userId
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
// Act
await _sut.Execute(context);
// Assert - Verify Stripe update happened (item deleted)
await _stripeFacade.Received(1).UpdateSubscription(
"sub_delete",
Arg.Is<SubscriptionUpdateOptions>(o => o.Items.Count == 1 && o.Items[0].Deleted == true));
// Assert - Verify database update with base storage only (5 GB)
await _userRepository.Received(1).GetByIdAsync(userId);
await _userRepository.Received(1).ReplaceAsync(user);
Assert.Equal((short)5, user.MaxStorageGb);
}
#endregion
#region Price ID Processing Tests
@@ -174,11 +369,14 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.StorageReconciled2025] = "invalid-date"
};
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: metadata);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -200,7 +398,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: null);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -226,7 +427,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -253,7 +457,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 4);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -279,7 +486,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 2);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -309,7 +519,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -333,7 +546,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -429,9 +645,12 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10);
var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5);
var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3));
@@ -461,6 +680,7 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var processedMetadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o")
@@ -469,6 +689,8 @@ public class ReconcileAdditionalStorageJobTests
var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10);
var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5, metadata: processedMetadata);
var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3));
@@ -501,9 +723,12 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10);
var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5);
var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3));
@@ -563,7 +788,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -585,7 +813,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Trialing);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -607,7 +838,10 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.PastDue);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(subscription));
@@ -669,11 +903,14 @@ public class ReconcileAdditionalStorageJobTests
_featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true);
_featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true);
var userId = Guid.NewGuid();
var activeSubscription = CreateSubscription("sub_active", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active);
var trialingSubscription = CreateSubscription("sub_trialing", "storage-gb-monthly", quantity: 8, status: StripeConstants.SubscriptionStatus.Trialing);
var pastDueSubscription = CreateSubscription("sub_pastdue", "storage-gb-monthly", quantity: 6, status: StripeConstants.SubscriptionStatus.PastDue);
var canceledSubscription = CreateSubscription("sub_canceled", "storage-gb-monthly", quantity: 5, status: StripeConstants.SubscriptionStatus.Canceled);
var incompleteSubscription = CreateSubscription("sub_incomplete", "storage-gb-monthly", quantity: 4, status: StripeConstants.SubscriptionStatus.Incomplete);
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any<SubscriptionListOptions>())
.Returns(AsyncEnumerable.Create(activeSubscription, trialingSubscription, pastDueSubscription, canceledSubscription, incompleteSubscription));
@@ -731,6 +968,410 @@ public class ReconcileAdditionalStorageJobTests
#endregion
#region Helper Method Tests
#region DetermineSubscriptionPlanTier Tests
[Fact]
public void DetermineSubscriptionPlanTier_WithUserId_ReturnsPersonal()
{
// Arrange
var userId = Guid.NewGuid();
Guid? organizationId = null;
// Act
var result = _sut.DetermineSubscriptionPlanTier(userId, organizationId);
// Assert
Assert.Equal(ReconcileAdditionalStorageJob.SubscriptionPlanTier.Personal, result);
}
[Fact]
public void DetermineSubscriptionPlanTier_WithOrganizationId_ReturnsOrganization()
{
// Arrange
Guid? userId = null;
var organizationId = Guid.NewGuid();
// Act
var result = _sut.DetermineSubscriptionPlanTier(userId, organizationId);
// Assert
Assert.Equal(ReconcileAdditionalStorageJob.SubscriptionPlanTier.Organization, result);
}
[Fact]
public void DetermineSubscriptionPlanTier_WithBothIds_ReturnsPersonal()
{
// Arrange - Personal takes precedence
var userId = Guid.NewGuid();
var organizationId = Guid.NewGuid();
// Act
var result = _sut.DetermineSubscriptionPlanTier(userId, organizationId);
// Assert
Assert.Equal(ReconcileAdditionalStorageJob.SubscriptionPlanTier.Personal, result);
}
[Fact]
public void DetermineSubscriptionPlanTier_WithNoIds_ReturnsUnknown()
{
// Arrange
Guid? userId = null;
Guid? organizationId = null;
// Act
var result = _sut.DetermineSubscriptionPlanTier(userId, organizationId);
// Assert
Assert.Equal(ReconcileAdditionalStorageJob.SubscriptionPlanTier.Unknown, result);
}
#endregion
#region GetCurrentStorageQuantityFromSubscription Tests
[Theory]
[InlineData("storage-gb-monthly", 10L, 10L)]
[InlineData("storage-gb-annually", 25L, 25L)]
[InlineData("personal-storage-gb-annually", 5L, 5L)]
[InlineData("storage-gb-monthly", 0L, 0L)]
public void GetCurrentStorageQuantityFromSubscription_WithMatchingPriceId_ReturnsQuantity(
string priceId, long quantity, long expectedQuantity)
{
// Arrange
var subscription = CreateSubscription("sub_123", priceId, quantity);
// Act
var result = _sut.GetCurrentStorageQuantityFromSubscription(subscription, priceId);
// Assert
Assert.Equal(expectedQuantity, result);
}
[Fact]
public void GetCurrentStorageQuantityFromSubscription_WithNonMatchingPriceId_ReturnsZero()
{
// Arrange
var subscription = CreateSubscription("sub_123", "storage-gb-monthly", 10L);
// Act
var result = _sut.GetCurrentStorageQuantityFromSubscription(subscription, "different-price-id");
// Assert
Assert.Equal(0, result);
}
[Fact]
public void GetCurrentStorageQuantityFromSubscription_WithNullItems_ReturnsZero()
{
// Arrange
var subscription = new Subscription { Id = "sub_123", Items = null };
// Act
var result = _sut.GetCurrentStorageQuantityFromSubscription(subscription, "storage-gb-monthly");
// Assert
Assert.Equal(0, result);
}
[Fact]
public void GetCurrentStorageQuantityFromSubscription_WithEmptyItems_ReturnsZero()
{
// Arrange
var subscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem> { Data = [] }
};
// Act
var result = _sut.GetCurrentStorageQuantityFromSubscription(subscription, "storage-gb-monthly");
// Assert
Assert.Equal(0, result);
}
#endregion
#region CalculateNewMaxStorageGb Tests
[Theory]
[InlineData(10L, 6L, 11)] // 5 included + 6 new quantity
[InlineData(15L, 11L, 16)] // 5 included + 11 new quantity
[InlineData(4L, 0L, 5)] // Item deleted, returns base storage
[InlineData(2L, 0L, 5)] // Item deleted, returns base storage
[InlineData(8L, 4L, 9)] // 5 included + 4 new quantity
public void CalculateNewMaxStorageGb_WithQuantityUpdate_ReturnsCorrectMaxStorage(
long currentQuantity, long newQuantity, short expectedMaxStorageGb)
{
// Arrange
var updateOptions = new SubscriptionUpdateOptions
{
Items =
[
newQuantity == 0
? new SubscriptionItemOptions { Id = "si_123", Deleted = true } // Item marked as deleted
: new SubscriptionItemOptions { Id = "si_123", Quantity = newQuantity } // Item quantity updated
]
};
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, updateOptions);
// Assert
Assert.Equal(expectedMaxStorageGb, result);
}
[Fact]
public void CalculateNewMaxStorageGb_WithNullUpdateOptions_ReturnsCurrentQuantityPlusBaseIncluded()
{
// Arrange
const long currentQuantity = 10;
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, null);
// Assert
Assert.Equal((short)(5 + currentQuantity), result);
}
[Fact]
public void CalculateNewMaxStorageGb_WithNullItems_ReturnsCurrentQuantityPlusBaseIncluded()
{
// Arrange
const long currentQuantity = 10;
var updateOptions = new SubscriptionUpdateOptions { Items = null };
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, updateOptions);
// Assert
Assert.Equal(5 + currentQuantity, result);
}
[Fact]
public void CalculateNewMaxStorageGb_WithEmptyItems_ReturnsCurrentQuantity()
{
// Arrange
const long currentQuantity = 10;
var updateOptions = new SubscriptionUpdateOptions
{
Items = []
};
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, updateOptions);
// Assert
Assert.Equal(5 + currentQuantity, result);
}
[Fact]
public void CalculateNewMaxStorageGb_WithDeletedItem_ReturnsBaseStorage()
{
// Arrange
const long currentQuantity = 100;
var updateOptions = new SubscriptionUpdateOptions
{
Items = [new SubscriptionItemOptions { Id = "si_123", Deleted = true }]
};
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, updateOptions);
// Assert
Assert.Equal((short)5, result); // Base storage
}
[Fact]
public void CalculateNewMaxStorageGb_WithItemWithoutQuantity_ReturnsCurrentQuantity()
{
// Arrange
const long currentQuantity = 10;
var updateOptions = new SubscriptionUpdateOptions
{
Items = [new SubscriptionItemOptions { Id = "si_123", Quantity = null }]
};
// Act
var result = _sut.CalculateNewMaxStorageGb(currentQuantity, updateOptions);
// Assert
Assert.Equal(5 + currentQuantity, result);
}
#endregion
#region UpdateDatabaseMaxStorageAsync Tests
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_PersonalTier_UpdatesUser()
{
// Arrange
var userId = Guid.NewGuid();
var user = new Bit.Core.Entities.User
{
Id = userId,
Email = "test@example.com",
GatewaySubscriptionId = "sub_123"
};
_userRepository.GetByIdAsync(userId).Returns(user);
_userRepository.ReplaceAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Personal,
userId,
10,
"sub_123");
// Assert
Assert.True(result);
Assert.Equal((short)10, user.MaxStorageGb);
await _userRepository.Received(1).GetByIdAsync(userId);
await _userRepository.Received(1).ReplaceAsync(user);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_PersonalTier_UserNotFound_ReturnsFalse()
{
// Arrange
var userId = Guid.NewGuid();
_userRepository.GetByIdAsync(userId).Returns((Bit.Core.Entities.User?)null);
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Personal,
userId,
10,
"sub_123");
// Assert
Assert.False(result);
await _userRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default!);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_PersonalTier_ReplaceThrowsException_ReturnsFalse()
{
// Arrange
var userId = Guid.NewGuid();
var user = new Bit.Core.Entities.User
{
Id = userId,
Email = "test@example.com",
GatewaySubscriptionId = "sub_123"
};
_userRepository.GetByIdAsync(userId).Returns(user);
_userRepository.ReplaceAsync(user).Throws(new Exception("Database error"));
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Personal,
userId,
10,
"sub_123");
// Assert
Assert.False(result);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_OrganizationTier_UpdatesOrganization()
{
// Arrange
var organizationId = Guid.NewGuid();
var organization = new Bit.Core.AdminConsole.Entities.Organization
{
Id = organizationId,
Name = "Test Org",
GatewaySubscriptionId = "sub_456"
};
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
_organizationRepository.ReplaceAsync(organization).Returns(Task.CompletedTask);
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Organization,
organizationId,
20,
"sub_456");
// Assert
Assert.True(result);
Assert.Equal((short)20, organization.MaxStorageGb);
await _organizationRepository.Received(1).GetByIdAsync(organizationId);
await _organizationRepository.Received(1).ReplaceAsync(organization);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_OrganizationTier_OrganizationNotFound_ReturnsFalse()
{
// Arrange
var organizationId = Guid.NewGuid();
_organizationRepository.GetByIdAsync(organizationId)
.Returns((Bit.Core.AdminConsole.Entities.Organization?)null);
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Organization,
organizationId,
20,
"sub_456");
// Assert
Assert.False(result);
await _organizationRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default!);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_OrganizationTier_ReplaceThrowsException_ReturnsFalse()
{
// Arrange
var organizationId = Guid.NewGuid();
var organization = new Bit.Core.AdminConsole.Entities.Organization
{
Id = organizationId,
Name = "Test Org",
GatewaySubscriptionId = "sub_456"
};
_organizationRepository.GetByIdAsync(organizationId).Returns(organization);
_organizationRepository.ReplaceAsync(organization).Throws(new Exception("Database error"));
// Act
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Organization,
organizationId,
20,
"sub_456");
// Assert
Assert.False(result);
}
[Fact]
public async Task UpdateDatabaseMaxStorageAsync_UnknownTier_ReturnsFalse()
{
// Arrange & Act
var entityId = Guid.NewGuid();
var result = await _sut.UpdateDatabaseMaxStorageAsync(
ReconcileAdditionalStorageJob.SubscriptionPlanTier.Unknown,
entityId,
15,
"sub_789");
// Assert
Assert.False(result);
await _userRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
await _organizationRepository.DidNotReceiveWithAnyArgs().GetByIdAsync(default);
}
#endregion
#endregion
#region Helper Methods
private static IJobExecutionContext CreateJobExecutionContext(CancellationToken cancellationToken = default)
@@ -762,7 +1403,27 @@ public class ReconcileAdditionalStorageJobTests
Metadata = metadata,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem> { item }
Data = [item]
}
};
}
private static Subscription CreateSubscriptionWithMultipleItems(string id, (string priceId, long quantity)[] items)
{
var subscriptionItems = items.Select(i => new SubscriptionItem
{
Id = $"si_{id}_{i.priceId}",
Price = new Price { Id = i.priceId },
Quantity = i.quantity
}).ToList();
return new Subscription
{
Id = id,
Status = StripeConstants.SubscriptionStatus.Active,
Items = new StripeList<SubscriptionItem>
{
Data = subscriptionItems
}
};
}

View File

@@ -26,6 +26,7 @@ public class SutProvider<TSut> : ISutProvider
public TSut Sut { get; private set; }
public Type SutType => typeof(TSut);
public IFixture Fixture => _fixture;
public SutProvider() : this(new Fixture()) { }
@@ -65,6 +66,19 @@ public class SutProvider<TSut> : ISutProvider
return this;
}
/// <summary>
/// Creates and registers a dependency to be injected when the sut is created.
/// </summary>
/// <typeparam name="TDep">The Dependency type to create</typeparam>
/// <param name="parameterName">The (optional) parameter name to register the dependency under</param>
/// <returns>The created dependency value</returns>
public TDep CreateDependency<TDep>(string parameterName = "")
{
var dependency = _fixture.Create<TDep>();
SetDependency(dependency, parameterName);
return dependency;
}
/// <summary>
/// Gets a dependency of the sut. Can only be called after the dependency has been set, either explicitly with
/// <see cref="SetDependency{T}"/> or automatically with <see cref="Create"/>.

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<RootNamespace>Bit.Test.Common</RootNamespace>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,7 +1,10 @@
using System.Text.Json;
using System.Reflection;
using System.Text.Json;
using System.Text.RegularExpressions;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Billing.Organizations.Models;
using Bit.Test.Common.Helpers;
using Xunit;
@@ -96,4 +99,124 @@ public class OrganizationTests
var host = Assert.Contains("Host", (IDictionary<string, object>)duo.MetaData);
Assert.Equal("Host_value", host);
}
[Fact]
public void UseDisableSmAdsForUsers_DefaultValue_IsFalse()
{
var organization = new Organization();
Assert.False(organization.UseDisableSmAdsForUsers);
}
[Fact]
public void UseDisableSmAdsForUsers_CanBeSetToTrue()
{
var organization = new Organization
{
UseDisableSmAdsForUsers = true
};
Assert.True(organization.UseDisableSmAdsForUsers);
}
[Fact]
public void UpdateFromLicense_AppliesAllLicenseProperties()
{
// This test ensures that when a new property is added to OrganizationLicense,
// it is also applied to the Organization in UpdateFromLicense().
// This is the fourth step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Get all public properties from OrganizationLicense
var licenseProperties = typeof(OrganizationLicense)
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Select(p => p.Name)
.ToHashSet();
// 2. Define properties that don't need to be applied to Organization
var excludedProperties = new HashSet<string>
{
// Internal/computed properties
"SignatureBytes", // Computed from Signature property
"ValidLicenseVersion", // Internal property, not serialized
"CurrentLicenseFileVersion", // Constant field, not an instance property
"Hash", // Signature-related, not applied to org
"Signature", // Signature-related, not applied to org
"Token", // The JWT itself, not applied to org
"Version", // License version, not stored on org
// Properties intentionally excluded from UpdateFromLicense
"Id", // Self-hosted org has its own unique Guid
"MaxStorageGb", // Not enforced for self-hosted (per comment in UpdateFromLicense)
// Properties not stored on Organization model
"LicenseType", // Not a property on Organization
"InstallationId", // Not a property on Organization
"Issued", // Not a property on Organization
"Refresh", // Not a property on Organization
"ExpirationWithoutGracePeriod", // Not a property on Organization
"Trial", // Not a property on Organization
"Expires", // Mapped to ExpirationDate on Organization (different name)
// Deprecated properties not applied
"LimitCollectionCreationDeletion", // Deprecated, not applied
"AllowAdminAccessToAllCollectionItems", // Deprecated, not applied
};
// 3. Get properties that should be applied
var propertiesThatShouldBeApplied = licenseProperties
.Except(excludedProperties)
.ToHashSet();
// 4. Read Organization.UpdateFromLicense source code
var organizationSourcePath = Path.Combine(
Directory.GetCurrentDirectory(),
"..", "..", "..", "..", "..", "src", "Core", "AdminConsole", "Entities", "Organization.cs");
var sourceCode = File.ReadAllText(organizationSourcePath);
// 5. Find all property assignments in UpdateFromLicense method
// Pattern matches: PropertyName = license.PropertyName
// This regex looks for assignments like "Name = license.Name" or "ExpirationDate = license.Expires"
var assignmentPattern = @"(\w+)\s*=\s*license\.(\w+)";
var matches = Regex.Matches(sourceCode, assignmentPattern);
var appliedProperties = new HashSet<string>();
foreach (Match match in matches)
{
// Get the license property name (right side of assignment)
var licensePropertyName = match.Groups[2].Value;
appliedProperties.Add(licensePropertyName);
}
// Special case: Expires is mapped to ExpirationDate
if (appliedProperties.Contains("Expires"))
{
appliedProperties.Add("Expires"); // Already added, but being explicit
}
// 6. Find missing applications
var missingApplications = propertiesThatShouldBeApplied
.Except(appliedProperties)
.OrderBy(p => p)
.ToList();
// 7. Build error message with guidance
var errorMessage = "";
if (missingApplications.Any())
{
errorMessage = $"The following OrganizationLicense properties are NOT applied to Organization in UpdateFromLicense():\n";
errorMessage += string.Join("\n", missingApplications.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following lines to Organization.UpdateFromLicense():\n";
foreach (var prop in missingApplications)
{
errorMessage += $" {prop} = license.{prop};\n";
}
errorMessage += "\nNote: If the property maps to a different name on Organization (like Expires → ExpirationDate), adjust accordingly.";
}
// 8. Assert - if this fails, the error message guides the developer to add the application
Assert.True(
!missingApplications.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -240,6 +240,6 @@ public class SendOrganizationConfirmationCommandTests
}
}
private static string GetSubject(string organizationName) => $"You Have Been Confirmed To {organizationName}";
private static string GetSubject(string organizationName) => $"You can now access items from {organizationName}";
}

View File

@@ -283,7 +283,7 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Invited,
UserId = Guid.NewGuid(),
UserId = null,
Email = "invited@example.com"
};
@@ -302,6 +302,56 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests
Assert.True(string.IsNullOrEmpty(result));
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_MixedUsersWithNullUserId_HandlesCorrectly(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,
Guid confirmedUserId,
SutProvider<AutomaticUserConfirmationPolicyEventHandler> sutProvider)
{
// Arrange
var invitedUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Invited,
UserId = null,
Email = "invited@example.com"
};
var confirmedUser = new OrganizationUserUserDetails
{
Id = Guid.NewGuid(),
OrganizationId = policyUpdate.OrganizationId,
Type = OrganizationUserType.User,
Status = OrganizationUserStatusType.Confirmed,
UserId = confirmedUserId,
Email = "confirmed@example.com"
};
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId)
.Returns([invitedUser, confirmedUser]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([]);
sutProvider.GetDependency<IProviderUserRepository>()
.GetManyByManyUsersAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([]);
// Act
var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null);
// Assert
Assert.True(string.IsNullOrEmpty(result));
await sutProvider.GetDependency<IOrganizationUserRepository>()
.Received(1)
.GetManyByManyUsersAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 1 && ids.First() == confirmedUserId));
}
[Theory, BitAutoData]
public async Task ValidateAsync_EnablingPolicy_RevokedUsersIncluded_InComplianceCheck(
[PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate,

View File

@@ -1,8 +1,10 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.UserMasterPassword;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -21,106 +23,154 @@ public class SetInitialMasterPasswordCommandTests
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_Success(SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, string masterPassword, string key, string orgIdentifier,
Organization org, OrganizationUser orgUser)
User user, UserAccountKeysData accountKeys, KdfSettings kdfSettings,
Organization org, OrganizationUser orgUser, string serverSideHash, string masterPasswordHint)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
user.Key = null;
var model = CreateValidModel(user, accountKeys, kdfSettings, org.Identifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.GetByIdentifierAsync(org.Identifier)
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
sutProvider.GetDependency<IPasswordHasher<User>>()
.HashPassword(user, model.MasterPasswordAuthentication.MasterPasswordAuthenticationHash)
.Returns(serverSideHash);
// Assert
Assert.Equal(IdentityResult.Success, result);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_UserIsNull_ThrowsArgumentNullException(SutProvider<SetInitialMasterPasswordCommand> sutProvider, string masterPassword, string key, string orgIdentifier)
{
// Act & Assert
await Assert.ThrowsAsync<ArgumentNullException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(null, masterPassword, key, orgIdentifier));
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_AlreadyHasPassword_ReturnsFalse(SutProvider<SetInitialMasterPasswordCommand> sutProvider, User user, string masterPassword, string key, string orgIdentifier)
{
// Arrange
user.MasterPassword = "ExistingPassword";
// Mock SetMasterPassword to return a specific UpdateUserData delegate
UpdateUserData mockUpdateUserData = (connection, transaction) => Task.CompletedTask;
sutProvider.GetDependency<IUserRepository>()
.SetMasterPassword(user.Id, model.MasterPasswordUnlock, serverSideHash, model.MasterPasswordHint)
.Returns(mockUpdateUserData);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model);
// Assert
Assert.False(result.Succeeded);
await sutProvider.GetDependency<IUserRepository>().Received(1)
.SetV2AccountCryptographicStateAsync(
user.Id,
model.AccountKeys,
Arg.Do<IEnumerable<UpdateUserData>>(actions =>
{
var actionsList = actions.ToList();
Assert.Single(actionsList);
Assert.Same(mockUpdateUserData, actionsList[0]);
}));
await sutProvider.GetDependency<IEventService>().Received(1)
.LogUserEventAsync(user.Id, EventType.User_ChangedPassword);
await sutProvider.GetDependency<IAcceptOrgUserCommand>().Received(1)
.AcceptOrgUserAsync(orgUser, user, sutProvider.GetDependency<IUserService>());
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_NullOrgSsoIdentifier_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommand> sutProvider, User user, string masterPassword, string key)
public async Task SetInitialMasterPassword_UserAlreadyHasPassword_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, UserAccountKeysData accountKeys, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.MasterPassword = null;
string orgSsoIdentifier = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
user.Key = "existing-key";
var model = CreateValidModel(user, accountKeys, kdfSettings, orgSsoIdentifier, masterPasswordHint);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgSsoIdentifier));
Assert.Equal("Organization SSO Identifier required.", exception.Message);
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model));
Assert.Equal("User already has a master password set.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_InvalidOrganization_Throws(SutProvider<SetInitialMasterPasswordCommand> sutProvider, User user, string masterPassword, string key, string orgIdentifier)
public async Task SetInitialMasterPassword_AccountKeysNull_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.MasterPassword = null;
user.Key = null;
var model = CreateValidModel(user, null, kdfSettings, orgSsoIdentifier, masterPasswordHint);
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model));
Assert.Equal("Account keys are required.", exception.Message);
}
[Theory]
[BitAutoData("wrong-salt", null)]
[BitAutoData([null, "wrong-salt"])]
[BitAutoData("wrong-salt", "different-wrong-salt")]
public async Task SetInitialMasterPassword_InvalidSalt_ThrowsBadRequestException(
string? authSaltOverride, string? unlockSaltOverride,
SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, UserAccountKeysData accountKeys, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = null;
var correctSalt = user.GetMasterPasswordSalt();
var model = new SetInitialMasterPasswordDataModel
{
MasterPasswordAuthentication = new MasterPasswordAuthenticationData
{
Salt = authSaltOverride ?? correctSalt,
MasterPasswordAuthenticationHash = "hash",
Kdf = kdfSettings
},
MasterPasswordUnlock = new MasterPasswordUnlockData
{
Salt = unlockSaltOverride ?? correctSalt,
MasterKeyWrappedUserKey = "wrapped-key",
Kdf = kdfSettings
},
AccountKeys = accountKeys,
OrgSsoIdentifier = orgSsoIdentifier,
MasterPasswordHint = masterPasswordHint
};
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model));
Assert.Equal("Invalid master password salt.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_InvalidOrgSsoIdentifier_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, UserAccountKeysData accountKeys, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = null;
var model = CreateValidModel(user, accountKeys, kdfSettings, orgSsoIdentifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.GetByIdentifierAsync(orgSsoIdentifier)
.ReturnsNull();
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier));
Assert.Equal("Organization invalid.", exception.Message);
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model));
Assert.Equal("Organization SSO identifier is invalid.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_UserNotFoundInOrganization_Throws(SutProvider<SetInitialMasterPasswordCommand> sutProvider, User user, string masterPassword, string key, Organization org)
public async Task SetInitialMasterPassword_UserNotFoundInOrganization_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, UserAccountKeysData accountKeys, KdfSettings kdfSettings, Organization org, string masterPasswordHint)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
user.Key = null;
var model = CreateValidModel(user, accountKeys, kdfSettings, org.Identifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(Arg.Any<string>())
.GetByIdentifierAsync(org.Identifier)
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
@@ -128,67 +178,33 @@ public class SetInitialMasterPasswordCommandTests
.ReturnsNull();
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, org.Identifier));
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, model));
Assert.Equal("User not found within organization.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_ConfirmedOrgUser_DoesNotCallAcceptOrgUser(SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, string masterPassword, string key, string orgIdentifier, Organization org, OrganizationUser orgUser)
private static SetInitialMasterPasswordDataModel CreateValidModel(
User user, UserAccountKeysData? accountKeys, KdfSettings kdfSettings,
string orgSsoIdentifier, string? masterPasswordHint)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.Returns(org);
orgUser.Status = OrganizationUserStatusType.Confirmed;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.Equal(IdentityResult.Success, result);
await sutProvider.GetDependency<IAcceptOrgUserCommand>().DidNotReceive().AcceptOrgUserAsync(Arg.Any<OrganizationUser>(), Arg.Any<User>(), Arg.Any<IUserService>());
var salt = user.GetMasterPasswordSalt();
return new SetInitialMasterPasswordDataModel
{
MasterPasswordAuthentication = new MasterPasswordAuthenticationData
{
Salt = salt,
MasterPasswordAuthenticationHash = "hash",
Kdf = kdfSettings
},
MasterPasswordUnlock = new MasterPasswordUnlockData
{
Salt = salt,
MasterKeyWrappedUserKey = "wrapped-key",
Kdf = kdfSettings
},
AccountKeys = accountKeys,
OrgSsoIdentifier = orgSsoIdentifier,
MasterPasswordHint = masterPasswordHint
};
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_InvitedOrgUser_CallsAcceptOrgUser(SutProvider<SetInitialMasterPasswordCommand> sutProvider,
User user, string masterPassword, string key, string orgIdentifier, Organization org, OrganizationUser orgUser)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.Returns(org);
orgUser.Status = OrganizationUserStatusType.Invited;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.Equal(IdentityResult.Success, result);
await sutProvider.GetDependency<IAcceptOrgUserCommand>().Received(1).AcceptOrgUserAsync(orgUser, user, sutProvider.GetDependency<IUserService>());
}
}

View File

@@ -0,0 +1,194 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.UserFeatures.UserMasterPassword;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
namespace Bit.Core.Test.Auth.UserFeatures.UserMasterPassword;
[SutProviderCustomize]
public class SetInitialMasterPasswordCommandV1Tests
{
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_Success(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider,
User user, string masterPassword, string key, string orgIdentifier,
Organization org, OrganizationUser orgUser)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.Equal(IdentityResult.Success, result);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_UserIsNull_ThrowsArgumentNullException(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider, string masterPassword, string key, string orgIdentifier)
{
// Act & Assert
await Assert.ThrowsAsync<ArgumentNullException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(null, masterPassword, key, orgIdentifier));
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_AlreadyHasPassword_ReturnsFalse(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider, User user, string masterPassword, string key, string orgIdentifier)
{
// Arrange
user.MasterPassword = "ExistingPassword";
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.False(result.Succeeded);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_NullOrgSsoIdentifier_ThrowsBadRequestException(
SutProvider<SetInitialMasterPasswordCommandV1> sutProvider, User user, string masterPassword, string key)
{
// Arrange
user.MasterPassword = null;
string orgSsoIdentifier = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgSsoIdentifier));
Assert.Equal("Organization SSO Identifier required.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_InvalidOrganization_Throws(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider, User user, string masterPassword, string key, string orgIdentifier)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.ReturnsNull();
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier));
Assert.Equal("Organization invalid.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_UserNotFoundInOrganization_Throws(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider, User user, string masterPassword, string key, Organization org)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(Arg.Any<string>())
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.ReturnsNull();
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, org.Identifier));
Assert.Equal("User not found within organization.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_ConfirmedOrgUser_DoesNotCallAcceptOrgUser(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider,
User user, string masterPassword, string key, string orgIdentifier, Organization org, OrganizationUser orgUser)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.Returns(org);
orgUser.Status = OrganizationUserStatusType.Confirmed;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.Equal(IdentityResult.Success, result);
await sutProvider.GetDependency<IAcceptOrgUserCommand>().DidNotReceive().AcceptOrgUserAsync(Arg.Any<OrganizationUser>(), Arg.Any<User>(), Arg.Any<IUserService>());
}
[Theory]
[BitAutoData]
public async Task SetInitialMasterPassword_InvitedOrgUser_CallsAcceptOrgUser(SutProvider<SetInitialMasterPasswordCommandV1> sutProvider,
User user, string masterPassword, string key, string orgIdentifier, Organization org, OrganizationUser orgUser)
{
// Arrange
user.MasterPassword = null;
sutProvider.GetDependency<IUserService>()
.UpdatePasswordHash(Arg.Any<User>(), Arg.Any<string>(), true, false)
.Returns(IdentityResult.Success);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgIdentifier)
.Returns(org);
orgUser.Status = OrganizationUserStatusType.Invited;
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
// Act
var result = await sutProvider.Sut.SetInitialMasterPasswordAsync(user, masterPassword, key, orgIdentifier);
// Assert
Assert.Equal(IdentityResult.Success, result);
await sutProvider.GetDependency<IAcceptOrgUserCommand>().Received(1).AcceptOrgUserAsync(orgUser, user, sutProvider.GetDependency<IUserService>());
}
}

View File

@@ -0,0 +1,223 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.UserMasterPassword;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
namespace Bit.Core.Test.Auth.UserFeatures.UserMasterPassword;
[SutProviderCustomize]
public class TdeSetPasswordCommandTests
{
[Theory]
[BitAutoData]
public async Task OnboardMasterPassword_Success(SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings,
Organization org, OrganizationUser orgUser, string serverSideHash, string masterPasswordHint)
{
// Arrange
user.Key = null;
user.PublicKey = "public-key";
user.PrivateKey = "private-key";
var model = CreateValidModel(user, kdfSettings, org.Identifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(org.Identifier)
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.Returns(orgUser);
sutProvider.GetDependency<IPasswordHasher<User>>()
.HashPassword(user, model.MasterPasswordAuthentication.MasterPasswordAuthenticationHash)
.Returns(serverSideHash);
// Mock SetMasterPassword to return a specific UpdateUserData delegate
UpdateUserData mockUpdateUserData = (connection, transaction) => Task.CompletedTask;
sutProvider.GetDependency<IUserRepository>()
.SetMasterPassword(user.Id, model.MasterPasswordUnlock, serverSideHash, model.MasterPasswordHint)
.Returns(mockUpdateUserData);
// Act
await sutProvider.Sut.SetMasterPasswordAsync(user, model);
// Assert
await sutProvider.GetDependency<IUserRepository>().Received(1)
.UpdateUserDataAsync(Arg.Do<IEnumerable<UpdateUserData>>(actions =>
{
var actionsList = actions.ToList();
Assert.Single(actionsList);
Assert.Same(mockUpdateUserData, actionsList[0]);
}));
await sutProvider.GetDependency<IEventService>().Received(1)
.LogUserEventAsync(user.Id, EventType.User_ChangedPassword);
}
[Theory]
[BitAutoData]
public async Task OnboardMasterPassword_UserAlreadyHasPassword_ThrowsBadRequestException(
SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = "existing-key";
var model = CreateValidModel(user, kdfSettings, orgSsoIdentifier, masterPasswordHint);
// Act & Assert
var exception =
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetMasterPasswordAsync(user, model));
Assert.Equal("User already has a master password set.", exception.Message);
}
[Theory]
[BitAutoData([null, "private-key"])]
[BitAutoData("public-key", null)]
[BitAutoData([null, null])]
public async Task OnboardMasterPassword_MissingAccountKeys_ThrowsBadRequestException(
string? publicKey, string? privateKey,
SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = null;
user.PublicKey = publicKey;
user.PrivateKey = privateKey;
var model = CreateValidModel(user, kdfSettings, orgSsoIdentifier, masterPasswordHint);
// Act & Assert
var exception =
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetMasterPasswordAsync(user, model));
Assert.Equal("TDE user account keys must be set before setting initial master password.", exception.Message);
}
[Theory]
[BitAutoData("wrong-salt", null)]
[BitAutoData([null, "wrong-salt"])]
[BitAutoData("wrong-salt", "different-wrong-salt")]
public async Task OnboardMasterPassword_InvalidSalt_ThrowsBadRequestException(
string? authSaltOverride, string? unlockSaltOverride,
SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = null;
user.PublicKey = "public-key";
user.PrivateKey = "private-key";
var correctSalt = user.GetMasterPasswordSalt();
var model = new SetInitialMasterPasswordDataModel
{
MasterPasswordAuthentication =
new MasterPasswordAuthenticationData
{
Salt = authSaltOverride ?? correctSalt,
MasterPasswordAuthenticationHash = "hash",
Kdf = kdfSettings
},
MasterPasswordUnlock = new MasterPasswordUnlockData
{
Salt = unlockSaltOverride ?? correctSalt,
MasterKeyWrappedUserKey = "wrapped-key",
Kdf = kdfSettings
},
AccountKeys = null,
OrgSsoIdentifier = orgSsoIdentifier,
MasterPasswordHint = masterPasswordHint
};
// Act & Assert
var exception =
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetMasterPasswordAsync(user, model));
Assert.Equal("Invalid master password salt.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task OnboardMasterPassword_InvalidOrgSsoIdentifier_ThrowsBadRequestException(
SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string masterPasswordHint)
{
// Arrange
user.Key = null;
user.PublicKey = "public-key";
user.PrivateKey = "private-key";
var model = CreateValidModel(user, kdfSettings, orgSsoIdentifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(orgSsoIdentifier)
.ReturnsNull();
// Act & Assert
var exception =
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetMasterPasswordAsync(user, model));
Assert.Equal("Organization SSO identifier is invalid.", exception.Message);
}
[Theory]
[BitAutoData]
public async Task OnboardMasterPassword_UserNotFoundInOrganization_ThrowsBadRequestException(
SutProvider<TdeSetPasswordCommand> sutProvider,
User user, KdfSettings kdfSettings, Organization org, string masterPasswordHint)
{
// Arrange
user.Key = null;
user.PublicKey = "public-key";
user.PrivateKey = "private-key";
var model = CreateValidModel(user, kdfSettings, org.Identifier, masterPasswordHint);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdentifierAsync(org.Identifier)
.Returns(org);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetByOrganizationAsync(org.Id, user.Id)
.ReturnsNull();
// Act & Assert
var exception =
await Assert.ThrowsAsync<BadRequestException>(async () =>
await sutProvider.Sut.SetMasterPasswordAsync(user, model));
Assert.Equal("User not found within organization.", exception.Message);
}
private static SetInitialMasterPasswordDataModel CreateValidModel(
User user, KdfSettings kdfSettings, string orgSsoIdentifier, string? masterPasswordHint)
{
var salt = user.GetMasterPasswordSalt();
return new SetInitialMasterPasswordDataModel
{
MasterPasswordAuthentication =
new MasterPasswordAuthenticationData
{
Salt = salt,
MasterPasswordAuthenticationHash = "hash",
Kdf = kdfSettings
},
MasterPasswordUnlock =
new MasterPasswordUnlockData
{
Salt = salt,
MasterKeyWrappedUserKey = "wrapped-key",
Kdf = kdfSettings
},
AccountKeys = null,
OrgSsoIdentifier = orgSsoIdentifier,
MasterPasswordHint = masterPasswordHint
};
}
}

View File

@@ -0,0 +1,68 @@
using System.Reflection;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Organizations.Models;
using Xunit;
namespace Bit.Core.Test.Billing.Licenses;
public class LicenseConstantsTests
{
[Fact]
public void OrganizationLicenseConstants_HasConstantForEveryLicenseProperty()
{
// This test ensures that when a new property is added to OrganizationLicense,
// a corresponding constant is added to OrganizationLicenseConstants.
// This is the first step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Get all public properties from OrganizationLicense
var licenseProperties = typeof(OrganizationLicense)
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Select(p => p.Name)
.ToHashSet();
// 2. Get all constants from OrganizationLicenseConstants
var constants = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToHashSet();
// 3. Define properties that don't need constants (internal/computed/non-claims properties)
var excludedProperties = new HashSet<string>
{
"SignatureBytes", // Computed from Signature property
"ValidLicenseVersion", // Internal property, not serialized
"CurrentLicenseFileVersion", // Constant field, not an instance property
"Hash", // Signature-related, not in claims system
"Signature", // Signature-related, not in claims system
"Token", // The JWT itself, not a claim within the token
"Version" // Not in claims system (only in deprecated property-based licenses)
};
// 4. Find license properties without corresponding constants
var propertiesWithoutConstants = licenseProperties
.Except(constants)
.Except(excludedProperties)
.OrderBy(p => p)
.ToList();
// 5. Build error message with guidance
var errorMessage = "";
if (propertiesWithoutConstants.Any())
{
errorMessage = $"The following OrganizationLicense properties don't have constants in OrganizationLicenseConstants:\n";
errorMessage += string.Join("\n", propertiesWithoutConstants.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following constants to OrganizationLicenseConstants:\n";
foreach (var prop in propertiesWithoutConstants)
{
errorMessage += $" public const string {prop} = nameof({prop});\n";
}
}
// 6. Assert - if this fails, the error message guides the developer to add the constant
Assert.True(
!propertiesWithoutConstants.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -0,0 +1,92 @@
using System.Reflection;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Models;
using Bit.Core.Billing.Licenses.Services.Implementations;
using Bit.Core.Models.Business;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Core.Test.Billing.Licenses.Services.Implementations;
public class OrganizationLicenseClaimsFactoryTests
{
[Theory, BitAutoData]
public async Task GenerateClaims_CreatesClaimsForAllConstants(Organization organization)
{
// This test ensures that when a constant is added to OrganizationLicenseConstants,
// it is also added to the OrganizationLicenseClaimsFactory to generate claims.
// This is the second step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Populate all nullable properties to ensure claims can be generated
// The factory only adds claims for properties that have values
organization.Name = "Test Organization";
organization.BillingEmail = "billing@test.com";
organization.BusinessName = "Test Business";
organization.Plan = "Enterprise";
organization.LicenseKey = "test-license-key";
organization.Seats = 100;
organization.MaxCollections = 50;
organization.MaxStorageGb = 10;
organization.SmSeats = 25;
organization.SmServiceAccounts = 10;
organization.ExpirationDate = DateTime.UtcNow.AddYears(1); // Ensure org is not expired
// Create a LicenseContext with a minimal SubscriptionInfo to trigger conditional claims
// ExpirationWithoutGracePeriod is only generated for active, non-trial, annual subscriptions
var licenseContext = new LicenseContext
{
InstallationId = Guid.NewGuid(),
SubscriptionInfo = new SubscriptionInfo
{
Subscription = new SubscriptionInfo.BillingSubscription(null!)
{
TrialEndDate = DateTime.UtcNow.AddDays(-30), // Trial ended in the past
PeriodStartDate = DateTime.UtcNow,
PeriodEndDate = DateTime.UtcNow.AddDays(365), // Annual subscription (>180 days)
Status = "active"
}
}
};
// 2. Generate claims
var factory = new OrganizationLicenseClaimsFactory();
var claims = await factory.GenerateClaims(organization, licenseContext);
// 3. Get all constants from OrganizationLicenseConstants
var allConstants = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToHashSet();
// 4. Get claim types from generated claims
var generatedClaimTypes = claims.Select(c => c.Type).ToHashSet();
// 5. Find constants that don't have corresponding claims
var constantsWithoutClaims = allConstants
.Except(generatedClaimTypes)
.OrderBy(c => c)
.ToList();
// 6. Build error message with guidance
var errorMessage = "";
if (constantsWithoutClaims.Any())
{
errorMessage = $"The following constants in OrganizationLicenseConstants are NOT generated as claims in OrganizationLicenseClaimsFactory:\n";
errorMessage += string.Join("\n", constantsWithoutClaims.Select(c => $" - {c}"));
errorMessage += "\n\nPlease add the following claims to OrganizationLicenseClaimsFactory.GenerateClaims():\n";
foreach (var constant in constantsWithoutClaims)
{
errorMessage += $" new(nameof(OrganizationLicenseConstants.{constant}), entity.{constant}.ToString()),\n";
}
errorMessage += "\nNote: If the property is nullable, you may need to add it conditionally.";
}
// 7. Assert - if this fails, the error message guides the developer to add claim generation
Assert.True(
!constantsWithoutClaims.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -214,6 +214,7 @@ If you believe you need to change the version for a valid reason, please discuss
AllowAdminAccessToAllCollectionItems = true,
UseOrganizationDomains = true,
UseAdminSponsoredFamilies = false,
UseDisableSmAdsForUsers = false,
UsePhishingBlocker = false,
};
}
@@ -260,4 +261,34 @@ If you believe you need to change the version for a valid reason, please discuss
.Returns([0x00, 0x01, 0x02, 0x03]); // Dummy signature for hash testing
return mockService;
}
/// <summary>
/// Verifies that UseDisableSmAdsForUsers claim is properly generated in the license Token
/// and that VerifyData correctly validates the claim.
/// </summary>
[Theory]
[BitAutoData(true)]
[BitAutoData(false)]
public void OrganizationLicense_UseDisableSmAdsForUsers_ClaimGenerationAndValidation(bool useDisableSmAdsForUsers, ClaimsPrincipal claimsPrincipal)
{
// Arrange
var organization = CreateDeterministicOrganization();
organization.UseDisableSmAdsForUsers = useDisableSmAdsForUsers;
var subscriptionInfo = CreateDeterministicSubscriptionInfo();
var installationId = new Guid("78900000-0000-0000-0000-000000000123");
var mockLicensingService = CreateMockLicensingService();
var license = new OrganizationLicense(organization, subscriptionInfo, installationId, mockLicensingService);
license.Expires = DateTime.MaxValue; // Prevent expiration during test
var globalSettings = Substitute.For<IGlobalSettings>();
globalSettings.Installation.Returns(new GlobalSettings.InstallationSettings
{
Id = installationId
});
// Act & Assert - Verify VerifyData passes with the UseDisableSmAdsForUsers value
Assert.True(license.VerifyData(organization, claimsPrincipal, globalSettings));
}
}

View File

@@ -1,9 +1,14 @@
using System.Security.Claims;
using System.Reflection;
using System.Security.Claims;
using System.Text.RegularExpressions;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Services;
using Bit.Core.Settings;
@@ -88,7 +93,7 @@ public class UpdateOrganizationLicenseCommandTests
"Hash", "Signature", "SignatureBytes", "InstallationId", "Expires",
"ExpirationWithoutGracePeriod", "Token", "LimitCollectionCreationDeletion",
"LimitCollectionCreation", "LimitCollectionDeletion", "AllowAdminAccessToAllCollectionItems",
"UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation", "UsePhishingBlocker") &&
"UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation", "UsePhishingBlocker", "UseDisableSmAdsForUsers") &&
// Same property but different name, use explicit mapping
org.ExpirationDate == license.Expires));
}
@@ -99,6 +104,320 @@ public class UpdateOrganizationLicenseCommandTests
}
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_WithClaimsPrincipal_ExtractsAllPropertiesFromClaims(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup license for CanUse validation
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-1);
license.Expires = DateTime.Now.AddDays(1);
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.InstallationId = globalSettings.Installation.Id;
license.LicenseType = LicenseType.Organization;
license.Token = "test-token"; // Indicates this is a claims-based license
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
// Create a ClaimsPrincipal with all organization license claims
var claims = new List<Claim>
{
new(OrganizationLicenseConstants.LicenseType, ((int)LicenseType.Organization).ToString()),
new(OrganizationLicenseConstants.InstallationId, globalSettings.Installation.Id.ToString()),
new(OrganizationLicenseConstants.Name, "Test Organization"),
new(OrganizationLicenseConstants.BillingEmail, "billing@test.com"),
new(OrganizationLicenseConstants.BusinessName, "Test Business"),
new(OrganizationLicenseConstants.PlanType, ((int)PlanType.EnterpriseAnnually).ToString()),
new(OrganizationLicenseConstants.Seats, "100"),
new(OrganizationLicenseConstants.MaxCollections, "50"),
new(OrganizationLicenseConstants.UsePolicies, "true"),
new(OrganizationLicenseConstants.UseSso, "true"),
new(OrganizationLicenseConstants.UseKeyConnector, "true"),
new(OrganizationLicenseConstants.UseScim, "true"),
new(OrganizationLicenseConstants.UseGroups, "true"),
new(OrganizationLicenseConstants.UseDirectory, "true"),
new(OrganizationLicenseConstants.UseEvents, "true"),
new(OrganizationLicenseConstants.UseTotp, "true"),
new(OrganizationLicenseConstants.Use2fa, "true"),
new(OrganizationLicenseConstants.UseApi, "true"),
new(OrganizationLicenseConstants.UseResetPassword, "true"),
new(OrganizationLicenseConstants.Plan, "Enterprise"),
new(OrganizationLicenseConstants.SelfHost, "true"),
new(OrganizationLicenseConstants.UsersGetPremium, "true"),
new(OrganizationLicenseConstants.UseCustomPermissions, "true"),
new(OrganizationLicenseConstants.Enabled, "true"),
new(OrganizationLicenseConstants.Expires, DateTime.Now.AddDays(1).ToString("O")),
new(OrganizationLicenseConstants.LicenseKey, "test-license-key"),
new(OrganizationLicenseConstants.UsePasswordManager, "true"),
new(OrganizationLicenseConstants.UseSecretsManager, "true"),
new(OrganizationLicenseConstants.SmSeats, "25"),
new(OrganizationLicenseConstants.SmServiceAccounts, "10"),
new(OrganizationLicenseConstants.UseRiskInsights, "true"),
new(OrganizationLicenseConstants.UseOrganizationDomains, "true"),
new(OrganizationLicenseConstants.UseAdminSponsoredFamilies, "true"),
new(OrganizationLicenseConstants.UseAutomaticUserConfirmation, "true"),
new(OrganizationLicenseConstants.UseDisableSmAdsForUsers, "true"),
new(OrganizationLicenseConstants.UsePhishingBlocker, "true"),
new(OrganizationLicenseConstants.MaxStorageGb, "5"),
new(OrganizationLicenseConstants.Issued, DateTime.Now.AddDays(-1).ToString("O")),
new(OrganizationLicenseConstants.Refresh, DateTime.Now.AddMonths(1).ToString("O")),
new(OrganizationLicenseConstants.ExpirationWithoutGracePeriod, DateTime.Now.AddMonths(12).ToString("O")),
new(OrganizationLicenseConstants.Trial, "false"),
new(OrganizationLicenseConstants.LimitCollectionCreationDeletion, "true"),
new(OrganizationLicenseConstants.AllowAdminAccessToAllCollectionItems, "true")
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(claims));
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns(claimsPrincipal);
// Setup selfHostedOrg for CanUseLicense validation
selfHostedOrg.OccupiedSeatCount = 50; // Less than the 100 seats in the license
selfHostedOrg.CollectionCount = 10; // Less than the 50 max collections in the license
selfHostedOrg.GroupCount = 1;
selfHostedOrg.UseGroups = true;
selfHostedOrg.UsePolicies = true;
selfHostedOrg.UseSso = true;
selfHostedOrg.UseKeyConnector = true;
selfHostedOrg.UseScim = true;
selfHostedOrg.UseCustomPermissions = true;
selfHostedOrg.UseResetPassword = true;
try
{
await sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null);
// Assertion: license file should be written to disk
var filePath = Path.Combine(LicenseDirectory, "organization", $"{selfHostedOrg.Id}.json");
await using var fs = File.OpenRead(filePath);
var licenseFromFile = await JsonSerializer.DeserializeAsync<OrganizationLicense>(fs);
AssertHelper.AssertPropertyEqual(license, licenseFromFile, "SignatureBytes");
// Assertion: organization should be updated with ALL properties extracted from claims
await sutProvider.GetDependency<IOrganizationService>()
.Received(1)
.ReplaceAndUpdateCacheAsync(Arg.Is<Organization>(org =>
org.Name == "Test Organization" &&
org.BillingEmail == "billing@test.com" &&
org.BusinessName == "Test Business" &&
org.PlanType == PlanType.EnterpriseAnnually &&
org.Seats == 100 &&
org.MaxCollections == 50 &&
org.UsePolicies == true &&
org.UseSso == true &&
org.UseKeyConnector == true &&
org.UseScim == true &&
org.UseGroups == true &&
org.UseDirectory == true &&
org.UseEvents == true &&
org.UseTotp == true &&
org.Use2fa == true &&
org.UseApi == true &&
org.UseResetPassword == true &&
org.Plan == "Enterprise" &&
org.SelfHost == true &&
org.UsersGetPremium == true &&
org.UseCustomPermissions == true &&
org.Enabled == true &&
org.LicenseKey == "test-license-key" &&
org.UsePasswordManager == true &&
org.UseSecretsManager == true &&
org.SmSeats == 25 &&
org.SmServiceAccounts == 10 &&
org.UseRiskInsights == true &&
org.UseOrganizationDomains == true &&
org.UseAdminSponsoredFamilies == true &&
org.UseAutomaticUserConfirmation == true &&
org.UseDisableSmAdsForUsers == true &&
org.UsePhishingBlocker == true));
}
finally
{
// Clean up temporary directory
if (Directory.Exists(OrganizationLicenseDirectory.Value))
{
Directory.Delete(OrganizationLicenseDirectory.Value, true);
}
}
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_WrongInstallationIdInClaims_ThrowsBadRequestException(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup license for CanUse validation
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-1);
license.Expires = DateTime.Now.AddDays(1);
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.LicenseType = LicenseType.Organization;
license.Token = "test-token"; // Indicates this is a claims-based license
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
// Create a ClaimsPrincipal with WRONG installation ID
var wrongInstallationId = Guid.NewGuid(); // Different from globalSettings.Installation.Id
var claims = new List<Claim>
{
new(OrganizationLicenseConstants.LicenseType, ((int)LicenseType.Organization).ToString()),
new(OrganizationLicenseConstants.InstallationId, wrongInstallationId.ToString()),
new(OrganizationLicenseConstants.Enabled, "true"),
new(OrganizationLicenseConstants.SelfHost, "true")
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(claims));
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns(claimsPrincipal);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null));
Assert.Contains("The installation ID does not match the current installation.", exception.Message);
// Verify organization was NOT saved
await sutProvider.GetDependency<IOrganizationService>()
.DidNotReceive()
.ReplaceAndUpdateCacheAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_ExpiredLicenseWithoutClaims_ThrowsBadRequestException(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup legacy license (no Token, no claims)
license.Token = null; // Legacy license
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-2);
license.Expires = DateTime.Now.AddDays(-1); // Expired yesterday
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.InstallationId = globalSettings.Installation.Id;
license.LicenseType = LicenseType.Organization;
license.SelfHost = true;
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns((ClaimsPrincipal)null); // No claims for legacy license
// Passing values for SelfHostedOrganizationDetails.CanUseLicense
license.Seats = null;
license.MaxCollections = null;
license.UseGroups = true;
license.UsePolicies = true;
license.UseSso = true;
license.UseKeyConnector = true;
license.UseScim = true;
license.UseCustomPermissions = true;
license.UseResetPassword = true;
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null));
Assert.Contains("The license has expired.", exception.Message);
// Verify organization was NOT saved
await sutProvider.GetDependency<IOrganizationService>()
.DidNotReceive()
.ReplaceAndUpdateCacheAsync(Arg.Any<Organization>());
}
[Fact]
public async Task UpdateLicenseAsync_ExtractsAllClaimsBasedProperties_WhenClaimsPrincipalProvided()
{
// This test ensures that when new properties are added to OrganizationLicense,
// they are automatically extracted from JWT claims in UpdateOrganizationLicenseCommand.
// If a new constant is added to OrganizationLicenseConstants but not extracted,
// this test will fail with a clear message showing which properties are missing.
// 1. Get all OrganizationLicenseConstants
var constantFields = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.GetField)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToList();
// 2. Define properties that should be excluded (not claims-based or intentionally not extracted)
var excludedProperties = new HashSet<string>
{
"Version", // Not in claims system (only in deprecated property-based licenses)
"Hash", // Signature-related, not extracted from claims
"Signature", // Signature-related, not extracted from claims
"SignatureBytes", // Computed from Signature, not a claim
"Token", // The JWT itself, not extracted from claims
"Id" // Cloud org ID from license, not used - self-hosted org has its own separate ID
};
// 3. Get properties that should be extracted from claims
var propertiesThatShouldBeExtracted = constantFields
.Where(c => !excludedProperties.Contains(c))
.ToHashSet();
// 4. Read UpdateOrganizationLicenseCommand source code
var commandSourcePath = Path.Combine(
Directory.GetCurrentDirectory(),
"..", "..", "..", "..", "..",
"src", "Core", "Billing", "Organizations", "Commands", "UpdateOrganizationLicenseCommand.cs");
var sourceCode = await File.ReadAllTextAsync(commandSourcePath);
// 5. Find all GetValue calls that extract properties from claims
// Pattern matches: license.PropertyName = claimsPrincipal.GetValue<Type>(OrganizationLicenseConstants.PropertyName)
var extractedProperties = new HashSet<string>();
var getValuePattern = @"claimsPrincipal\.GetValue<[^>]+>\(OrganizationLicenseConstants\.(\w+)\)";
var matches = Regex.Matches(sourceCode, getValuePattern);
foreach (Match match in matches)
{
extractedProperties.Add(match.Groups[1].Value);
}
// 6. Find missing extractions
var missingExtractions = propertiesThatShouldBeExtracted
.Except(extractedProperties)
.OrderBy(p => p)
.ToList();
// 7. Build error message with guidance if there are missing extractions
var errorMessage = "";
if (missingExtractions.Any())
{
errorMessage = $"The following constants in OrganizationLicenseConstants are NOT extracted from claims in UpdateOrganizationLicenseCommand:\n";
errorMessage += string.Join("\n", missingExtractions.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following lines to UpdateOrganizationLicenseCommand.cs in the 'if (claimsPrincipal != null)' block:\n";
foreach (var prop in missingExtractions)
{
errorMessage += $" license.{prop} = claimsPrincipal.GetValue<TYPE>(OrganizationLicenseConstants.{prop});\n";
}
}
// 8. Assert - if this fails, the error message guides the developer to add the extraction
// Note: We don't check for "extra extractions" because that would be a compile error
// (can't reference OrganizationLicenseConstants.Foo if Foo doesn't exist)
Assert.True(
!missingExtractions.Any(),
$"\n{errorMessage}");
}
// Wrapper to compare 2 objects that are different types
private bool AssertPropertyEqual(OrganizationLicense expected, Organization actual, params string[] excludedPropertyStrings)
{

View File

@@ -8,6 +8,7 @@ using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
@@ -29,6 +30,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands;
public class CreatePremiumCloudHostedSubscriptionCommandTests
{
private readonly IBraintreeGateway _braintreeGateway = Substitute.For<IBraintreeGateway>();
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IGlobalSettings _globalSettings = Substitute.For<IGlobalSettings>();
private readonly ISetupIntentCache _setupIntentCache = Substitute.For<ISetupIntentCache>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
@@ -59,6 +61,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
_command = new CreatePremiumCloudHostedSubscriptionCommand(
_braintreeGateway,
_braintreeService,
_globalSettings,
_setupIntentCache,
_stripeAdapter,
@@ -235,11 +238,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.LatestInvoiceId = "in_123";
var mockInvoice = Substitute.For<Invoice>();
@@ -258,6 +265,12 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any<SubscriptionCreateOptions>());
await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand != null &&
opts.Expand.Contains("customer")));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
@@ -456,11 +469,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "incomplete";
mockSubscription.LatestInvoiceId = "in_123";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
@@ -487,6 +504,12 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
Assert.True(user.Premium);
Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand != null &&
opts.Expand.Contains("customer")));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
}
[Theory, BitAutoData]
@@ -559,11 +582,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active"; // PayPal + active doesn't match pattern
mockSubscription.LatestInvoiceId = "in_123";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
@@ -590,6 +617,12 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
Assert.False(user.Premium);
Assert.Null(user.PremiumExpirationDate);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts =>
opts.AutoAdvance == false &&
opts.Expand != null &&
opts.Expand.Contains("customer")));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
}
[Theory, BitAutoData]

View File

@@ -18,13 +18,11 @@ public class UpdatePremiumStorageCommandTests
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly PremiumPlan _premiumPlan;
private readonly UpdatePremiumStorageCommand _command;
public UpdatePremiumStorageCommandTests()
{
// Setup default premium plan with standard pricing
_premiumPlan = new PremiumPlan
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
@@ -32,7 +30,7 @@ public class UpdatePremiumStorageCommandTests
Seat = new PremiumPurchasable { Price = 10M, StripePriceId = "price_premium", Provided = 1 },
Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "price_storage", Provided = 1 }
};
_pricingClient.ListPremiumPlans().Returns(new List<PremiumPlan> { _premiumPlan });
_pricingClient.ListPremiumPlans().Returns([premiumPlan]);
_command = new UpdatePremiumStorageCommand(
_stripeAdapter,
@@ -43,18 +41,19 @@ public class UpdatePremiumStorageCommandTests
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null)
{
var items = new List<SubscriptionItem>();
// Always add the seat item
items.Add(new SubscriptionItem
var items = new List<SubscriptionItem>
{
Id = "si_seat",
Price = new Price { Id = "price_premium" },
Quantity = 1
});
// Always add the seat item
new()
{
Id = "si_seat",
Price = new Price { Id = "price_premium" },
Quantity = 1
}
};
// Add storage item if quantity is provided
if (storageQuantity.HasValue && storageQuantity.Value > 0)
if (storageQuantity is > 0)
{
items.Add(new SubscriptionItem
{
@@ -142,7 +141,7 @@ public class UpdatePremiumStorageCommandTests
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("No access to storage.", badRequest.Response);
Assert.Equal("User has no access to storage.", badRequest.Response);
}
[Theory, BitAutoData]
@@ -216,7 +215,7 @@ public class UpdatePremiumStorageCommandTests
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
opts.ProrationBehavior == "always_invoice"));
// Verify user was saved
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
@@ -233,7 +232,7 @@ public class UpdatePremiumStorageCommandTests
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", null);
var subscription = CreateMockSubscription("sub_123");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
// Act

View File

@@ -0,0 +1,646 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan;
using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class UpgradePremiumToOrganizationCommandTests
{
// Concrete test implementation of the abstract Plan record
private record TestPlan : Core.Models.StaticStore.Plan
{
public TestPlan(
PlanType planType,
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
{
Type = planType;
ProductTier = ProductTierType.Teams;
Name = "Test Plan";
IsAnnual = true;
NameLocalizationKey = "";
DescriptionLocalizationKey = "";
CanBeUsedByBusiness = true;
TrialPeriodDays = null;
HasSelfHost = false;
HasPolicies = false;
HasGroups = false;
HasDirectory = false;
HasEvents = false;
HasTotp = false;
Has2fa = false;
HasApi = false;
HasSso = false;
HasOrganizationDomains = false;
HasKeyConnector = false;
HasScim = false;
HasResetPassword = false;
UsersGetPremium = false;
HasCustomPermissions = false;
UpgradeSortOrder = 0;
DisplaySortOrder = 0;
LegacyYear = null;
Disabled = false;
PasswordManager = new PasswordManagerPlanFeatures
{
StripePlanId = stripePlanId,
StripeSeatPlanId = stripeSeatPlanId,
StripePremiumAccessPlanId = stripePremiumAccessPlanId,
StripeStoragePlanId = stripeStoragePlanId,
BasePrice = 0,
SeatPrice = 0,
ProviderPortalSeatPrice = 0,
AllowSeatAutoscale = true,
HasAdditionalSeatsOption = true,
BaseSeats = 1,
HasPremiumAccessOption = !string.IsNullOrEmpty(stripePremiumAccessPlanId),
PremiumAccessOptionPrice = 0,
MaxSeats = null,
BaseStorageGb = 1,
HasAdditionalStorageOption = !string.IsNullOrEmpty(stripeStoragePlanId),
AdditionalStoragePricePerGb = 0,
MaxCollections = null
};
SecretsManager = null;
}
}
private static Core.Models.StaticStore.Plan CreateTestPlan(
PlanType planType,
string? stripePlanId = null,
string? stripeSeatPlanId = null,
string? stripePremiumAccessPlanId = null,
string? stripeStoragePlanId = null)
{
return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId);
}
private static PremiumPlan CreateTestPremiumPlan(
string seatPriceId = "premium-annually",
string storagePriceId = "personal-storage-gb-annually",
bool available = true)
{
return new PremiumPlan
{
Name = "Premium",
LegacyYear = null,
Available = available,
Seat = new PremiumPurchasable
{
StripePriceId = seatPriceId,
Price = 10m,
Provided = 1
},
Storage = new PremiumPurchasable
{
StripePriceId = storagePriceId,
Price = 4m,
Provided = 1
}
};
}
private static List<PremiumPlan> CreateTestPremiumPlansList()
{
return new List<PremiumPlan>
{
// Current available plan
CreateTestPremiumPlan("premium-annually", "personal-storage-gb-annually", available: true),
// Legacy plan from 2020
CreateTestPremiumPlan("premium-annually-2020", "personal-storage-gb-annually-2020", available: false)
};
}
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IOrganizationUserRepository _organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository = Substitute.For<IOrganizationApiKeyRepository>();
private readonly IApplicationCacheService _applicationCacheService = Substitute.For<IApplicationCacheService>();
private readonly ILogger<UpgradePremiumToOrganizationCommand> _logger = Substitute.For<ILogger<UpgradePremiumToOrganizationCommand>>();
private readonly UpgradePremiumToOrganizationCommand _command;
public UpgradePremiumToOrganizationCommandTests()
{
_command = new UpgradePremiumToOrganizationCommand(
_logger,
_pricingClient,
_stripeAdapter,
_userService,
_organizationRepository,
_organizationUserRepository,
_organizationApiKeyRepository,
_applicationCacheService);
}
[Theory, BitAutoData]
public async Task Run_UserNotPremium_ReturnsBadRequest(User user)
{
// Arrange
user.Premium = false;
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("User does not have an active Premium subscription.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_UserNoGatewaySubscriptionId_ReturnsBadRequest(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = null;
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("User does not have an active Premium subscription.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_UserEmptyGatewaySubscriptionId_ReturnsBadRequest(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "";
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("User does not have an active Premium subscription.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_SuccessfulUpgrade_SeatBasedPlan_ReturnsSuccess(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
user.Id = Guid.NewGuid();
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually"
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted + 1 seat (no storage)
opts.Items.Any(i => i.Deleted == true) &&
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
await _organizationRepository.Received(1).CreateAsync(Arg.Is<Organization>(o =>
o.Name == "My Organization" &&
o.Gateway == GatewayType.Stripe &&
o.GatewaySubscriptionId == "sub_123" &&
o.GatewayCustomerId == "cus_123"));
await _organizationUserRepository.Received(1).CreateAsync(Arg.Is<OrganizationUser>(ou =>
ou.Key == "encrypted-key" &&
ou.Status == OrganizationUserStatusType.Confirmed));
await _organizationApiKeyRepository.Received(1).CreateAsync(Arg.Any<OrganizationApiKey>());
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
u.Premium == false &&
u.GatewaySubscriptionId == null &&
u.GatewayCustomerId == null));
}
[Theory, BitAutoData]
public async Task Run_SuccessfulUpgrade_NonSeatBasedPlan_ReturnsSuccess(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.FamiliesAnnually,
stripePlanId: "families-plan-annually",
stripeSeatPlanId: null // Non-seat-based
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted + 1 plan
opts.Items.Any(i => i.Deleted == true) &&
opts.Items.Any(i => i.Price == "families-plan-annually" && i.Quantity == 1)));
await _organizationRepository.Received(1).CreateAsync(Arg.Is<Organization>(o =>
o.Name == "My Families Org"));
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
u.Premium == false &&
u.GatewaySubscriptionId == null));
}
[Theory, BitAutoData]
public async Task Run_AddsMetadataWithOriginalPremiumPriceId(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1)
}
}
},
Metadata = new Dictionary<string, string>
{
["userId"] = user.Id.ToString()
}
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually"
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPremiumPriceId) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousPremiumPriceId] == "premium-annually" &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPeriodEndDate) &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "0" &&
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.UserId) &&
opts.Metadata[StripeConstants.MetadataKeys.UserId] == string.Empty)); // Removes userId to unlink from User
}
[Theory, BitAutoData]
public async Task Run_UserOnLegacyPremiumPlan_SuccessfullyDeletesLegacyItems(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium_legacy",
Price = new Price { Id = "premium-annually-2020" }, // Legacy price ID
CurrentPeriodEnd = currentPeriodEnd
},
new SubscriptionItem
{
Id = "si_storage_legacy",
Price = new Price { Id = "personal-storage-gb-annually-2020" }, // Legacy storage price ID
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually"
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT0);
// Verify that BOTH legacy items (password manager + storage) are deleted by ID
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 3 && // 2 deleted (legacy PM + legacy storage) + 1 new seat
opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium_legacy") == 1 && // Legacy PM deleted
opts.Items.Count(i => i.Deleted == true && i.Id == "si_storage_legacy") == 1 && // Legacy storage deleted
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
}
[Theory, BitAutoData]
public async Task Run_UserHasPremiumPlusOtherProducts_OnlyDeletesPremiumItems(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
},
new SubscriptionItem
{
Id = "si_other_product",
Price = new Price { Id = "some-other-product-id" }, // Non-premium item
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually"
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT0);
// Verify that ONLY the premium password manager item is deleted (not other products)
// Note: We delete the specific premium item by ID, so other products are untouched
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Items.Count == 2 && // 1 deleted (premium password manager) + 1 new seat
opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium") == 1 && // Premium item deleted by ID
opts.Items.Count(i => i.Id == "si_other_product") == 0 && // Other product NOT in update (untouched)
opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1)));
}
[Theory, BitAutoData]
public async Task Run_UserHasAdditionalStorage_CapturesStorageInMetadata(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
user.GatewayCustomerId = "cus_123";
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_premium",
Price = new Price { Id = "premium-annually" },
CurrentPeriodEnd = currentPeriodEnd
},
new SubscriptionItem
{
Id = "si_storage",
Price = new Price { Id = "personal-storage-gb-annually" },
Quantity = 5, // User has 5GB additional storage
CurrentPeriodEnd = currentPeriodEnd
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
var mockPlan = CreateTestPlan(
PlanType.TeamsAnnually,
stripeSeatPlanId: "teams-seat-annually"
);
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
_pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan);
_stripeAdapter.UpdateSubscriptionAsync(Arg.Any<string>(), Arg.Any<SubscriptionUpdateOptions>())
.Returns(Task.FromResult(mockSubscription));
_organizationRepository.CreateAsync(Arg.Any<Organization>()).Returns(callInfo => Task.FromResult(callInfo.Arg<Organization>()));
_organizationApiKeyRepository.CreateAsync(Arg.Any<OrganizationApiKey>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationApiKey>()));
_organizationUserRepository.CreateAsync(Arg.Any<OrganizationUser>()).Returns(callInfo => Task.FromResult(callInfo.Arg<OrganizationUser>()));
_applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any<Organization>()).Returns(Task.CompletedTask);
_userService.SaveUserAsync(user).Returns(Task.CompletedTask);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT0);
// Verify that the additional storage quantity (5) is captured in metadata
await _stripeAdapter.Received(1).UpdateSubscriptionAsync(
"sub_123",
Arg.Is<SubscriptionUpdateOptions>(opts =>
opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) &&
opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "5" &&
opts.Items.Count == 3 && // 2 deleted (premium + storage) + 1 new seat
opts.Items.Count(i => i.Deleted == true) == 2));
}
[Theory, BitAutoData]
public async Task Run_NoPremiumSubscriptionItemFound_ReturnsBadRequest(User user)
{
// Arrange
user.Premium = true;
user.GatewaySubscriptionId = "sub_123";
var mockSubscription = new Subscription
{
Id = "sub_123",
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new SubscriptionItem
{
Id = "si_other",
Price = new Price { Id = "some-other-product" }, // Not a premium plan
CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1)
}
}
},
Metadata = new Dictionary<string, string>()
};
var mockPremiumPlans = CreateTestPremiumPlansList();
_stripeAdapter.GetSubscriptionAsync("sub_123")
.Returns(mockSubscription);
_pricingClient.ListPremiumPlans().Returns(mockPremiumPlans);
// Act
var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Premium subscription item not found.", badRequest.Response);
}
}

View File

@@ -0,0 +1,607 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Billing.Subscriptions.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Subscriptions.Queries;
using static StripeConstants;
public class GetBitwardenSubscriptionQueryTests
{
private readonly ILogger<GetBitwardenSubscriptionQuery> _logger = Substitute.For<ILogger<GetBitwardenSubscriptionQuery>>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly GetBitwardenSubscriptionQuery _query;
public GetBitwardenSubscriptionQueryTests()
{
_query = new GetBitwardenSubscriptionQuery(
_logger,
_pricingClient,
_stripeAdapter);
}
[Fact]
public async Task Run_IncompleteStatus_ReturnsBitwardenSubscriptionWithSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Incomplete);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Incomplete, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(subscription.Created.AddHours(23), result.Suspension);
Assert.Equal(1, result.GracePeriod);
Assert.Null(result.NextCharge);
Assert.Null(result.CancelAt);
}
[Fact]
public async Task Run_IncompleteExpiredStatus_ReturnsBitwardenSubscriptionWithSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.IncompleteExpired);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.IncompleteExpired, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(subscription.Created.AddHours(23), result.Suspension);
Assert.Equal(1, result.GracePeriod);
}
[Fact]
public async Task Run_TrialingStatus_ReturnsBitwardenSubscriptionWithNextCharge()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Trialing);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Trialing, result.Status);
Assert.NotNull(result.NextCharge);
Assert.Equal(subscription.Items.First().CurrentPeriodEnd, result.NextCharge);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_ActiveStatus_ReturnsBitwardenSubscriptionWithNextCharge()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Active, result.Status);
Assert.NotNull(result.NextCharge);
Assert.Equal(subscription.Items.First().CurrentPeriodEnd, result.NextCharge);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_ActiveStatusWithCancelAt_ReturnsCancelAt()
{
var user = CreateUser();
var cancelAt = DateTime.UtcNow.AddMonths(1);
var subscription = CreateSubscription(SubscriptionStatus.Active, cancelAt: cancelAt);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Active, result.Status);
Assert.Equal(cancelAt, result.CancelAt);
}
[Fact]
public async Task Run_PastDueStatus_WithOpenInvoices_ReturnsSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.PastDue, collectionMethod: "charge_automatically");
var premiumPlans = CreatePremiumPlans();
var openInvoice = CreateInvoice();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([openInvoice]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.PastDue, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(openInvoice.Created.AddDays(14), result.Suspension);
Assert.Equal(14, result.GracePeriod);
}
[Fact]
public async Task Run_PastDueStatus_WithoutOpenInvoices_ReturnsNoSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.PastDue);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.PastDue, result.Status);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_UnpaidStatus_WithOpenInvoices_ReturnsSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Unpaid, collectionMethod: "charge_automatically");
var premiumPlans = CreatePremiumPlans();
var openInvoice = CreateInvoice();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([openInvoice]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Unpaid, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(14, result.GracePeriod);
}
[Fact]
public async Task Run_CanceledStatus_ReturnsCanceledDate()
{
var user = CreateUser();
var canceledAt = DateTime.UtcNow.AddDays(-5);
var subscription = CreateSubscription(SubscriptionStatus.Canceled, canceledAt: canceledAt);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Canceled, result.Status);
Assert.Equal(canceledAt, result.Canceled);
Assert.Null(result.Suspension);
Assert.Null(result.NextCharge);
}
[Fact]
public async Task Run_UnmanagedStatus_ThrowsConflictException()
{
var user = CreateUser();
var subscription = CreateSubscription("unmanaged_status");
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
await Assert.ThrowsAsync<ConflictException>(() => _query.Run(user));
}
[Fact]
public async Task Run_WithAdditionalStorage_IncludesStorageInCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: true);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.PasswordManager.AdditionalStorage);
Assert.Equal("additionalStorageGB", result.Cart.PasswordManager.AdditionalStorage.TranslationKey);
Assert.Equal(2, result.Cart.PasswordManager.AdditionalStorage.Quantity);
Assert.NotNull(result.Storage);
}
[Fact]
public async Task Run_WithoutAdditionalStorage_ExcludesStorageFromCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: false);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Null(result.Cart.PasswordManager.AdditionalStorage);
Assert.NotNull(result.Storage);
}
[Fact]
public async Task Run_WithCartLevelDiscount_IncludesDiscountInCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
subscription.Customer.Discount = CreateDiscount(discountType: "cart");
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.Discount);
Assert.Equal(BitwardenDiscountType.PercentOff, result.Cart.Discount.Type);
Assert.Equal(20, result.Cart.Discount.Value);
}
[Fact]
public async Task Run_WithProductLevelDiscount_IncludesDiscountInCartItem()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var productDiscount = CreateDiscount(discountType: "product", productId: "prod_premium_seat");
subscription.Discounts = [productDiscount];
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.PasswordManager.Seats.Discount);
Assert.Equal(BitwardenDiscountType.PercentOff, result.Cart.PasswordManager.Seats.Discount.Type);
}
[Fact]
public async Task Run_WithoutMaxStorageGb_ReturnsNullStorage()
{
var user = CreateUser();
user.MaxStorageGb = null;
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Null(result.Storage);
}
[Fact]
public async Task Run_CalculatesStorageCorrectly()
{
var user = CreateUser();
user.Storage = 5368709120; // 5 GB in bytes
user.MaxStorageGb = 10;
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: true);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Storage);
Assert.Equal(10, result.Storage.Available);
Assert.Equal(5.0, result.Storage.Used);
Assert.NotEmpty(result.Storage.ReadableUsed);
}
[Fact]
public async Task Run_TaxEstimation_WithInvoiceUpcomingNoneError_ReturnsZeroTax()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.ThrowsAsync(new StripeException { StripeError = new StripeError { Code = ErrorCodes.InvoiceUpcomingNone } });
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(0, result.Cart.EstimatedTax);
}
[Fact]
public async Task Run_MissingPasswordManagerSeatsItem_ThrowsConflictException()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
subscription.Items = new StripeList<SubscriptionItem>
{
Data = []
};
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
await Assert.ThrowsAsync<ConflictException>(() => _query.Run(user));
}
[Fact]
public async Task Run_IncludesEstimatedTax()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
var invoice = CreateInvoicePreview(totalTax: 500); // $5.00 tax
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(5.0m, result.Cart.EstimatedTax);
}
[Fact]
public async Task Run_SetsCadenceToAnnually()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(PlanCadenceType.Annually, result.Cart.Cadence);
}
#region Helper Methods
private static User CreateUser()
{
return new User
{
Id = Guid.NewGuid(),
GatewaySubscriptionId = "sub_test123",
MaxStorageGb = 1,
Storage = 1073741824 // 1 GB in bytes
};
}
private static Subscription CreateSubscription(
string status,
bool includeStorage = false,
DateTime? cancelAt = null,
DateTime? canceledAt = null,
string collectionMethod = "charge_automatically")
{
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var items = new List<SubscriptionItem>
{
new()
{
Id = "si_premium_seat",
Price = new Price
{
Id = "price_premium_seat",
UnitAmountDecimal = 1000,
Product = new Product { Id = "prod_premium_seat" }
},
Quantity = 1,
CurrentPeriodStart = DateTime.UtcNow,
CurrentPeriodEnd = currentPeriodEnd
}
};
if (includeStorage)
{
items.Add(new SubscriptionItem
{
Id = "si_storage",
Price = new Price
{
Id = "price_storage",
UnitAmountDecimal = 400,
Product = new Product { Id = "prod_storage" }
},
Quantity = 2,
CurrentPeriodStart = DateTime.UtcNow,
CurrentPeriodEnd = currentPeriodEnd
});
}
return new Subscription
{
Id = "sub_test123",
Status = status,
Created = DateTime.UtcNow.AddMonths(-1),
Customer = new Customer
{
Id = "cus_test123",
Discount = null
},
Items = new StripeList<SubscriptionItem>
{
Data = items
},
CancelAt = cancelAt,
CanceledAt = canceledAt,
CollectionMethod = collectionMethod,
Discounts = []
};
}
private static List<Bit.Core.Billing.Pricing.Premium.Plan> CreatePremiumPlans()
{
return
[
new()
{
Name = "Premium",
Available = true,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_premium_seat",
Price = 10.0m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_storage",
Price = 4.0m,
Provided = 1
}
}
];
}
private static Invoice CreateInvoice()
{
return new Invoice
{
Id = "in_test123",
Created = DateTime.UtcNow.AddDays(-10),
PeriodEnd = DateTime.UtcNow.AddDays(-5),
Attempted = true,
Status = "open"
};
}
private static Invoice CreateInvoicePreview(long totalTax = 0)
{
var taxes = totalTax > 0
? new List<InvoiceTotalTax> { new() { Amount = totalTax } }
: new List<InvoiceTotalTax>();
return new Invoice
{
Id = "in_preview",
TotalTaxes = taxes
};
}
private static Discount CreateDiscount(string discountType = "cart", string? productId = null)
{
var coupon = new Coupon
{
Valid = true,
PercentOff = 20,
AppliesTo = discountType == "product" && productId != null
? new CouponAppliesTo { Products = [productId] }
: new CouponAppliesTo { Products = [] }
};
return new Discount
{
Coupon = coupon
};
}
#endregion
}

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<RootNamespace>Bit.Core.Test</RootNamespace>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
@@ -30,7 +32,7 @@
<ItemGroup>
<!-- Email templates uses .hbs extension, they must be included for emails to work -->
<EmbeddedResource Include="**\*.hbs" />
<EmbeddedResource Include="Utilities\data\embeddedResource.txt" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,211 @@
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class PlayIdServiceTests
{
[Theory]
[BitAutoData]
public void InPlay_WhenPlayIdSetAndDevelopment_ReturnsTrue(
string playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.True(result);
Assert.Equal(playId, resultPlayId);
}
[Theory]
[BitAutoData]
public void InPlay_WhenPlayIdSetButNotDevelopment_ReturnsFalse(
string playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Production);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.False(result);
Assert.Equal(playId, resultPlayId);
}
[Theory]
[BitAutoData((string?)null)]
[BitAutoData("")]
public void InPlay_WhenPlayIdNullOrEmptyAndDevelopment_ReturnsFalse(
string? playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.False(result);
Assert.Empty(resultPlayId);
}
[Theory]
[BitAutoData]
public void PlayId_CanGetAndSet(string playId)
{
var hostEnvironment = Substitute.For<IHostEnvironment>();
var sut = new PlayIdService(hostEnvironment);
sut.PlayId = playId;
Assert.Equal(playId, sut.PlayId);
}
}
[SutProviderCustomize]
public class NeverPlayIdServicesTests
{
[Fact]
public void InPlay_ReturnsFalse()
{
var sut = new NeverPlayIdServices();
var result = sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[InlineData("test-play-id")]
[InlineData(null)]
public void PlayId_SetterDoesNothing_GetterReturnsNull(string? value)
{
var sut = new NeverPlayIdServices();
sut.PlayId = value;
Assert.Null(sut.PlayId);
}
}
[SutProviderCustomize]
public class PlayIdSingletonServiceTests
{
public static IEnumerable<object[]> SutProvider()
{
var sutProvider = new SutProvider<PlayIdSingletonService>();
var httpContext = sutProvider.CreateDependency<HttpContext>();
var serviceProvider = sutProvider.CreateDependency<IServiceProvider>();
var hostEnvironment = sutProvider.CreateDependency<IHostEnvironment>();
var playIdService = new PlayIdService(hostEnvironment);
sutProvider.SetDependency(playIdService);
httpContext.RequestServices.Returns(serviceProvider);
serviceProvider.GetService<PlayIdService>().Returns(playIdService);
serviceProvider.GetRequiredService<PlayIdService>().Returns(playIdService);
sutProvider.CreateDependency<IHttpContextAccessor>().HttpContext.Returns(httpContext);
sutProvider.Create();
return [[sutProvider]];
}
private void PrepHttpContext(
SutProvider<PlayIdSingletonService> sutProvider)
{
var httpContext = sutProvider.CreateDependency<HttpContext>();
var serviceProvider = sutProvider.CreateDependency<IServiceProvider>();
var PlayIdService = sutProvider.CreateDependency<PlayIdService>();
httpContext.RequestServices.Returns(serviceProvider);
serviceProvider.GetRequiredService<PlayIdService>().Returns(PlayIdService);
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns(httpContext);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenNoHttpContext_ReturnsFalse(
SutProvider<PlayIdSingletonService> sutProvider)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenNotDevelopment_ReturnsFalse(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
var scopedPlayIdService = sutProvider.GetDependency<PlayIdService>();
scopedPlayIdService.PlayId = playIdValue;
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Production);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenDevelopmentAndHttpContextWithPlayId_ReturnsTrue(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
sutProvider.GetDependency<PlayIdService>().PlayId = playIdValue;
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.True(result);
Assert.Equal(playIdValue, playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_SetterSetsOnScopedService(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
var scopedPlayIdService = sutProvider.GetDependency<PlayIdService>();
sutProvider.Sut.PlayId = playIdValue;
Assert.Equal(playIdValue, scopedPlayIdService.PlayId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_WhenNoHttpContext_GetterReturnsNull(
SutProvider<PlayIdSingletonService> sutProvider)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
var result = sutProvider.Sut.PlayId;
Assert.Null(result);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_WhenNoHttpContext_SetterDoesNotThrow(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
sutProvider.Sut.PlayId = playIdValue;
}
}

View File

@@ -0,0 +1,143 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class PlayItemServiceTests
{
[Theory]
[BitAutoData]
public async Task Record_User_WhenInPlay_RecordsPlayItem(
string playId,
User user,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = playId;
return true;
});
await sutProvider.Sut.Record(user);
await sutProvider.GetDependency<IPlayItemRepository>()
.Received(1)
.CreateAsync(Arg.Is<PlayItem>(pd =>
pd.PlayId == playId &&
pd.UserId == user.Id &&
pd.OrganizationId == null));
sutProvider.GetDependency<ILogger<PlayItemService>>()
.Received(1)
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o => o.ToString().Contains(user.Id.ToString()) && o.ToString().Contains(playId)),
null,
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_User_WhenNotInPlay_DoesNotRecordPlayItem(
User user,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = null;
return false;
});
await sutProvider.Sut.Record(user);
await sutProvider.GetDependency<IPlayItemRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<PlayItem>());
sutProvider.GetDependency<ILogger<PlayItemService>>()
.DidNotReceive()
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_Organization_WhenInPlay_RecordsPlayItem(
string playId,
Organization organization,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = playId;
return true;
});
await sutProvider.Sut.Record(organization);
await sutProvider.GetDependency<IPlayItemRepository>()
.Received(1)
.CreateAsync(Arg.Is<PlayItem>(pd =>
pd.PlayId == playId &&
pd.OrganizationId == organization.Id &&
pd.UserId == null));
sutProvider.GetDependency<ILogger<PlayItemService>>()
.Received(1)
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o => o.ToString().Contains(organization.Id.ToString()) && o.ToString().Contains(playId)),
null,
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_Organization_WhenNotInPlay_DoesNotRecordPlayItem(
Organization organization,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = null;
return false;
});
await sutProvider.Sut.Record(organization);
await sutProvider.GetDependency<IPlayItemRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<PlayItem>());
sutProvider.GetDependency<ILogger<PlayItemService>>()
.DidNotReceive()
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception?, string>>());
}
}

View File

@@ -1,4 +1,5 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Queries;
@@ -47,7 +48,7 @@ public class SendAuthenticationQueryTests
{
// Arrange
var sendId = Guid.NewGuid();
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: emailString, password: null);
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: emailString, password: null, AuthType.Email);
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
@@ -63,7 +64,7 @@ public class SendAuthenticationQueryTests
{
// Arrange
var sendId = Guid.NewGuid();
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: "hashedpassword");
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: "hashedpassword", AuthType.Email);
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
@@ -78,7 +79,7 @@ public class SendAuthenticationQueryTests
{
// Arrange
var sendId = Guid.NewGuid();
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null);
var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null, AuthType.None);
_sendRepository.GetByIdAsync(sendId).Returns(send);
// Act
@@ -105,11 +106,11 @@ public class SendAuthenticationQueryTests
public static IEnumerable<object[]> AuthenticationMethodTestCases()
{
yield return new object[] { null, typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 5, maxAccessCount: 5, emails: null, password: null), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 6, maxAccessCount: 5, emails: null, password: null), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: null), typeof(EmailOtp) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: "hashedpassword"), typeof(ResourcePassword) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null), typeof(NotAuthenticated) };
yield return new object[] { CreateSend(accessCount: 5, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 6, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: null, AuthType.Email), typeof(EmailOtp) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: "hashedpassword", AuthType.Password), typeof(ResourcePassword) };
yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null, AuthType.None), typeof(NotAuthenticated) };
}
public static IEnumerable<object[]> EmailParsingTestCases()
@@ -121,7 +122,7 @@ public class SendAuthenticationQueryTests
yield return new object[] { " , test@example.com, ,other@example.com, ", new[] { "test@example.com", "other@example.com" } };
}
private static Send CreateSend(int accessCount, int? maxAccessCount, string? emails, string? password)
private static Send CreateSend(int accessCount, int? maxAccessCount, string? emails, string? password, AuthType? authType)
{
return new Send
{
@@ -129,7 +130,8 @@ public class SendAuthenticationQueryTests
AccessCount = accessCount,
MaxAccessCount = maxAccessCount,
Emails = emails,
Password = password
Password = password,
AuthType = authType
};
}
}

View File

@@ -12,7 +12,6 @@ namespace Bit.Core.Test.Tools.Services;
public class SendOwnerQueryTests
{
private readonly ISendRepository _sendRepository;
private readonly IFeatureService _featureService;
private readonly IUserService _userService;
private readonly SendOwnerQuery _sendOwnerQuery;
private readonly Guid _currentUserId = Guid.NewGuid();
@@ -21,11 +20,10 @@ public class SendOwnerQueryTests
public SendOwnerQueryTests()
{
_sendRepository = Substitute.For<ISendRepository>();
_featureService = Substitute.For<IFeatureService>();
_userService = Substitute.For<IUserService>();
_user = new ClaimsPrincipal();
_userService.GetProperUserId(_user).Returns(_currentUserId);
_sendOwnerQuery = new SendOwnerQuery(_sendRepository, _featureService, _userService);
_sendOwnerQuery = new SendOwnerQuery(_sendRepository, _userService);
}
[Fact]
@@ -84,7 +82,7 @@ public class SendOwnerQueryTests
}
[Fact]
public async Task GetOwned_WithFeatureFlagEnabled_ReturnsAllSends()
public async Task GetOwned_ReturnsAllSendsIncludingEmailOTP()
{
// Arrange
var sends = new List<Send>
@@ -94,7 +92,6 @@ public class SendOwnerQueryTests
CreateSend(Guid.NewGuid(), _currentUserId, emails: "other@example.com")
};
_sendRepository.GetManyByUserIdAsync(_currentUserId).Returns(sends);
_featureService.IsEnabled(FeatureFlagKeys.PM19051_ListEmailOtpSends).Returns(true);
// Act
var result = await _sendOwnerQuery.GetOwned(_user);
@@ -105,28 +102,6 @@ public class SendOwnerQueryTests
Assert.Contains(sends[1], result);
Assert.Contains(sends[2], result);
await _sendRepository.Received(1).GetManyByUserIdAsync(_currentUserId);
_featureService.Received(1).IsEnabled(FeatureFlagKeys.PM19051_ListEmailOtpSends);
}
[Fact]
public async Task GetOwned_WithFeatureFlagDisabled_FiltersOutEmailOtpSends()
{
// Arrange
var sendWithoutEmails = CreateSend(Guid.NewGuid(), _currentUserId, emails: null);
var sendWithEmails = CreateSend(Guid.NewGuid(), _currentUserId, emails: "test@example.com");
var sends = new List<Send> { sendWithoutEmails, sendWithEmails };
_sendRepository.GetManyByUserIdAsync(_currentUserId).Returns(sends);
_featureService.IsEnabled(FeatureFlagKeys.PM19051_ListEmailOtpSends).Returns(false);
// Act
var result = await _sendOwnerQuery.GetOwned(_user);
// Assert
Assert.Single(result);
Assert.Contains(sendWithoutEmails, result);
Assert.DoesNotContain(sendWithEmails, result);
await _sendRepository.Received(1).GetManyByUserIdAsync(_currentUserId);
_featureService.Received(1).IsEnabled(FeatureFlagKeys.PM19051_ListEmailOtpSends);
}
[Fact]
@@ -147,7 +122,6 @@ public class SendOwnerQueryTests
// Arrange
var emptySends = new List<Send>();
_sendRepository.GetManyByUserIdAsync(_currentUserId).Returns(emptySends);
_featureService.IsEnabled(FeatureFlagKeys.PM19051_ListEmailOtpSends).Returns(true);
// Act
var result = await _sendOwnerQuery.GetOwned(_user);

View File

@@ -0,0 +1,219 @@
using System.Runtime.Serialization;
using System.Text.Json;
using System.Text.Json.Serialization;
using Bit.Core.Utilities;
using Xunit;
namespace Bit.Core.Test.Utilities;
public class EnumMemberJsonConverterTests
{
[Fact]
public void Serialize_WithEnumMemberAttribute_UsesAttributeValue()
{
// Arrange
var obj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.InProgress
};
const string expectedJsonString = "{\"Status\":\"in_progress\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Serialize_WithoutEnumMemberAttribute_UsesEnumName()
{
// Arrange
var obj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Pending
};
const string expectedJsonString = "{\"Status\":\"Pending\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Serialize_MultipleValues_SerializesCorrectly()
{
// Arrange
var obj = new EnumConverterTestObjectWithMultiple
{
Status1 = EnumConverterTestStatus.Active,
Status2 = EnumConverterTestStatus.InProgress,
Status3 = EnumConverterTestStatus.Pending
};
const string expectedJsonString = "{\"Status1\":\"active\",\"Status2\":\"in_progress\",\"Status3\":\"Pending\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Deserialize_WithEnumMemberAttribute_ReturnsCorrectEnumValue()
{
// Arrange
const string json = "{\"Status\":\"in_progress\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.InProgress, obj.Status);
}
[Fact]
public void Deserialize_WithoutEnumMemberAttribute_ReturnsCorrectEnumValue()
{
// Arrange
const string json = "{\"Status\":\"Pending\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.Pending, obj.Status);
}
[Fact]
public void Deserialize_MultipleValues_DeserializesCorrectly()
{
// Arrange
const string json = "{\"Status1\":\"active\",\"Status2\":\"in_progress\",\"Status3\":\"Pending\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObjectWithMultiple>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.Active, obj.Status1);
Assert.Equal(EnumConverterTestStatus.InProgress, obj.Status2);
Assert.Equal(EnumConverterTestStatus.Pending, obj.Status3);
}
[Fact]
public void Deserialize_InvalidEnumString_ThrowsJsonException()
{
// Arrange
const string json = "{\"Status\":\"invalid_value\"}";
// Act & Assert
var exception = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<EnumConverterTestObject>(json));
Assert.Contains("Unable to convert 'invalid_value' to EnumConverterTestStatus", exception.Message);
}
[Fact]
public void Deserialize_EmptyString_ThrowsJsonException()
{
// Arrange
const string json = "{\"Status\":\"\"}";
// Act & Assert
var exception = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<EnumConverterTestObject>(json));
Assert.Contains("Unable to convert '' to EnumConverterTestStatus", exception.Message);
}
[Fact]
public void RoundTrip_WithEnumMemberAttribute_PreservesValue()
{
// Arrange
var originalObj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Completed
};
// Act
var json = JsonSerializer.Serialize(originalObj);
var deserializedObj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(originalObj.Status, deserializedObj.Status);
}
[Fact]
public void RoundTrip_WithoutEnumMemberAttribute_PreservesValue()
{
// Arrange
var originalObj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Pending
};
// Act
var json = JsonSerializer.Serialize(originalObj);
var deserializedObj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(originalObj.Status, deserializedObj.Status);
}
[Fact]
public void Serialize_AllEnumValues_ProducesExpectedStrings()
{
// Arrange & Act & Assert
Assert.Equal("\"Pending\"", JsonSerializer.Serialize(EnumConverterTestStatus.Pending, CreateOptions()));
Assert.Equal("\"active\"", JsonSerializer.Serialize(EnumConverterTestStatus.Active, CreateOptions()));
Assert.Equal("\"in_progress\"", JsonSerializer.Serialize(EnumConverterTestStatus.InProgress, CreateOptions()));
Assert.Equal("\"completed\"", JsonSerializer.Serialize(EnumConverterTestStatus.Completed, CreateOptions()));
}
[Fact]
public void Deserialize_AllEnumValues_ReturnsCorrectEnums()
{
// Arrange & Act & Assert
Assert.Equal(EnumConverterTestStatus.Pending, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"Pending\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.Active, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"active\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.InProgress, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"in_progress\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.Completed, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"completed\"", CreateOptions()));
}
private static JsonSerializerOptions CreateOptions()
{
var options = new JsonSerializerOptions();
options.Converters.Add(new EnumMemberJsonConverter<EnumConverterTestStatus>());
return options;
}
}
public class EnumConverterTestObject
{
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status { get; set; }
}
public class EnumConverterTestObjectWithMultiple
{
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status1 { get; set; }
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status2 { get; set; }
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status3 { get; set; }
}
public enum EnumConverterTestStatus
{
Pending, // No EnumMemberAttribute
[EnumMember(Value = "active")]
Active,
[EnumMember(Value = "in_progress")]
InProgress,
[EnumMember(Value = "completed")]
Completed
}

View File

@@ -12,7 +12,6 @@ internal class OrganizationCipher : ICustomization
{
fixture.Customize<Cipher>(composer => composer
.With(c => c.OrganizationId, OrganizationId ?? Guid.NewGuid())
.Without(c => c.ArchivedDate)
.Without(c => c.UserId));
fixture.Customize<CipherDetails>(composer => composer
.With(c => c.OrganizationId, Guid.NewGuid())
@@ -28,7 +27,6 @@ internal class UserCipher : ICustomization
{
fixture.Customize<Cipher>(composer => composer
.With(c => c.UserId, UserId ?? Guid.NewGuid())
.Without(c => c.ArchivedDate)
.Without(c => c.OrganizationId));
fixture.Customize<CipherDetails>(composer => composer
.With(c => c.UserId, Guid.NewGuid())

View File

@@ -16,16 +16,15 @@ namespace Bit.Core.Test.Vault.Commands;
public class ArchiveCiphersCommandTest
{
[Theory]
[BitAutoData(true, false, 1, 1, 1)]
[BitAutoData(false, false, 1, 0, 1)]
[BitAutoData(false, true, 1, 0, 1)]
[BitAutoData(true, true, 1, 0, 1)]
public async Task ArchiveAsync_Works(
bool isEditable, bool hasOrganizationId,
[BitAutoData(true, 1, 1, 1)]
[BitAutoData(false, 1, 0, 1)]
[BitAutoData(false, 1, 0, 1)]
[BitAutoData(true, 1, 0, 1)]
public async Task ArchiveManyAsync_Works(
bool hasOrganizationId,
int cipherRepoCalls, int resultCountFromQuery, int pushNotificationsCalls,
SutProvider<ArchiveCiphersCommand> sutProvider, CipherDetails cipher, User user)
{
cipher.Edit = isEditable;
cipher.OrganizationId = hasOrganizationId ? Guid.NewGuid() : null;
var cipherList = new List<CipherDetails> { cipher };
@@ -46,4 +45,33 @@ public class ArchiveCiphersCommandTest
await sutProvider.GetDependency<IPushNotificationService>().Received(pushNotificationsCalls)
.PushSyncCiphersAsync(user.Id);
}
[Theory]
[BitAutoData]
public async Task ArchiveManyAsync_SetsArchivedDateOnReturnedCiphers(
SutProvider<ArchiveCiphersCommand> sutProvider,
CipherDetails cipher,
User user)
{
// Allow organization cipher to be archived in this test
cipher.OrganizationId = Guid.Parse("3f2504e0-4f89-11d3-9a0c-0305e82c3301");
sutProvider.GetDependency<ICipherRepository>()
.GetManyByUserIdAsync(user.Id)
.Returns(new List<CipherDetails> { cipher });
var repoRevisionDate = DateTime.UtcNow;
sutProvider.GetDependency<ICipherRepository>()
.ArchiveAsync(Arg.Any<IEnumerable<Guid>>(), user.Id)
.Returns(repoRevisionDate);
// Act
var result = await sutProvider.Sut.ArchiveManyAsync(new[] { cipher.Id }, user.Id);
// Assert
var archivedCipher = Assert.Single(result);
Assert.Equal(repoRevisionDate, archivedCipher.RevisionDate);
Assert.Equal(repoRevisionDate, archivedCipher.ArchivedDate);
}
}

View File

@@ -16,16 +16,15 @@ namespace Bit.Core.Test.Vault.Commands;
public class UnarchiveCiphersCommandTest
{
[Theory]
[BitAutoData(true, false, 1, 1, 1)]
[BitAutoData(false, false, 1, 0, 1)]
[BitAutoData(false, true, 1, 0, 1)]
[BitAutoData(true, true, 1, 1, 1)]
[BitAutoData(true, 1, 1, 1)]
[BitAutoData(false, 1, 0, 1)]
[BitAutoData(false, 1, 0, 1)]
[BitAutoData(true, 1, 1, 1)]
public async Task UnarchiveAsync_Works(
bool isEditable, bool hasOrganizationId,
bool hasOrganizationId,
int cipherRepoCalls, int resultCountFromQuery, int pushNotificationsCalls,
SutProvider<UnarchiveCiphersCommand> sutProvider, CipherDetails cipher, User user)
{
cipher.Edit = isEditable;
cipher.OrganizationId = hasOrganizationId ? Guid.NewGuid() : null;
var cipherList = new List<CipherDetails> { cipher };
@@ -46,4 +45,33 @@ public class UnarchiveCiphersCommandTest
await sutProvider.GetDependency<IPushNotificationService>().Received(pushNotificationsCalls)
.PushSyncCiphersAsync(user.Id);
}
[Theory]
[BitAutoData]
public async Task UnarchiveAsync_ClearsArchivedDateOnReturnedCiphers(
SutProvider<UnarchiveCiphersCommand> sutProvider,
CipherDetails cipher,
User user)
{
cipher.OrganizationId = null;
cipher.ArchivedDate = DateTime.UtcNow;
sutProvider.GetDependency<ICipherRepository>()
.GetManyByUserIdAsync(user.Id)
.Returns(new List<CipherDetails> { cipher });
var repoRevisionDate = DateTime.UtcNow.AddMinutes(1);
sutProvider.GetDependency<ICipherRepository>()
.UnarchiveAsync(Arg.Any<IEnumerable<Guid>>(), user.Id)
.Returns(repoRevisionDate);
// Act
var result = await sutProvider.Sut.UnarchiveManyAsync(new[] { cipher.Id }, user.Id);
// Assert
var unarchivedCipher = Assert.Single(result);
Assert.Equal(repoRevisionDate, unarchivedCipher.RevisionDate);
Assert.Null(unarchivedCipher.ArchivedDate);
}
}

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<PropertyGroup Condition=" '$(RunConfiguration)' == 'Identity.IntegrationTest' " />

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,5 +1,4 @@
using Bit.Core;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Context;
@@ -7,7 +6,6 @@ using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Identity.IdentityServer;
using Bit.Identity.Test.AutoFixture;
using Bit.Identity.Utilities;
@@ -25,7 +23,6 @@ public class UserDecryptionOptionsBuilderTests
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
private readonly UserDecryptionOptionsBuilder _builder;
private readonly IFeatureService _featureService;
public UserDecryptionOptionsBuilderTests()
{
@@ -33,8 +30,7 @@ public class UserDecryptionOptionsBuilderTests
_deviceRepository = Substitute.For<IDeviceRepository>();
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_loginApprovingClientTypes = Substitute.For<ILoginApprovingClientTypes>();
_featureService = Substitute.For<IFeatureService>();
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes, _featureService);
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes);
var user = new User();
_builder.ForUser(user);
}
@@ -227,43 +223,6 @@ public class UserDecryptionOptionsBuilderTests
Assert.False(result.TrustedDeviceOption?.HasLoginApprovingDevice);
}
/// <summary>
/// This logic has been flagged as part of PM-23174.
/// When removing the server flag, please also remove this test, and remove the FeatureService
/// dependency from this suite and the following test.
/// </summary>
/// <param name="organizationUserType"></param>
/// <param name="ssoConfig"></param>
/// <param name="configurationData"></param>
/// <param name="organization"></param>
/// <param name="organizationUser"></param>
/// <param name="user"></param>
[Theory]
[BitAutoData(OrganizationUserType.Custom)]
public async Task Build_WhenManageResetPasswordPermissions_ShouldReturnHasManageResetPasswordPermissionTrue(
OrganizationUserType organizationUserType,
SsoConfig ssoConfig,
SsoConfigurationData configurationData,
CurrentContextOrganization organization,
[OrganizationUserWithDefaultPermissions] OrganizationUser organizationUser,
User user)
{
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize();
ssoConfig.OrganizationId = organization.Id;
_currentContext.Organizations.Returns([organization]);
_currentContext.ManageResetPassword(organization.Id).Returns(true);
organizationUser.Type = organizationUserType;
organizationUser.OrganizationId = organization.Id;
organizationUser.UserId = user.Id;
organizationUser.SetPermissions(new Permissions() { ManageResetPassword = true });
_organizationUserRepository.GetByOrganizationAsync(ssoConfig.OrganizationId, user.Id).Returns(organizationUser);
var result = await _builder.ForUser(user).WithSso(ssoConfig).BuildAsync();
Assert.True(result.TrustedDeviceOption?.HasManageResetPasswordPermission);
}
[Theory]
[BitAutoData(OrganizationUserType.Custom)]
public async Task Build_WhenManageResetPasswordPermissions_ShouldFetchUserFromRepositoryAndReturnHasManageResetPasswordPermissionTrue(
@@ -274,8 +233,6 @@ public class UserDecryptionOptionsBuilderTests
[OrganizationUserWithDefaultPermissions] OrganizationUser organizationUser,
User user)
{
_featureService.IsEnabled(FeatureFlagKeys.PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword)
.Returns(true);
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize();
ssoConfig.OrganizationId = organization.Id;

View File

@@ -1,6 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">

View File

@@ -2,6 +2,7 @@
using Bit.Core.Entities;
using Bit.Core.Models.Data;
using Bit.Core.Test.AutoFixture.Attributes;
using Bit.Core.Utilities;
using Bit.Core.Vault.Entities;
using Bit.Infrastructure.EFIntegration.Test.AutoFixture;
using Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers;
@@ -279,4 +280,92 @@ public class CipherRepositoryTests
Assert.Equal(Core.Vault.Enums.CipherRepromptType.Password, savedCipher.Reprompt);
}
}
[CiSkippedTheory, EfUserCipherCustomize, BitAutoData]
public async Task ArchiveAsync_SetsArchivesJsonAndBumpsUserAccountRevisionDate(
Cipher cipher,
User user,
List<EfVaultRepo.CipherRepository> suts,
List<EfRepo.UserRepository> efUserRepos)
{
foreach (var sut in suts)
{
var i = suts.IndexOf(sut);
var efUser = await efUserRepos[i].CreateAsync(user);
efUserRepos[i].ClearChangeTracking();
cipher.UserId = efUser.Id;
cipher.OrganizationId = null;
var createdCipher = await sut.CreateAsync(cipher);
sut.ClearChangeTracking();
var archiveUtcNow = await sut.ArchiveAsync(new[] { createdCipher.Id }, efUser.Id);
sut.ClearChangeTracking();
var savedCipher = await sut.GetByIdAsync(createdCipher.Id);
Assert.NotNull(savedCipher);
Assert.Equal(archiveUtcNow, savedCipher.RevisionDate);
Assert.False(string.IsNullOrWhiteSpace(savedCipher.Archives));
var archives = CoreHelpers.LoadClassFromJsonData<Dictionary<Guid, DateTime>>(savedCipher.Archives);
Assert.NotNull(archives);
Assert.True(archives.ContainsKey(efUser.Id));
Assert.Equal(archiveUtcNow, archives[efUser.Id]);
var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id);
Assert.Equal(DateTime.UtcNow.ToShortDateString(), bumpedUser.AccountRevisionDate.ToShortDateString());
}
}
[CiSkippedTheory, EfUserCipherCustomize, BitAutoData]
public async Task UnarchiveAsync_RemovesUserFromArchivesJsonAndBumpsUserAccountRevisionDate(
Cipher cipher,
User user,
List<EfVaultRepo.CipherRepository> suts,
List<EfRepo.UserRepository> efUserRepos)
{
foreach (var sut in suts)
{
var i = suts.IndexOf(sut);
var efUser = await efUserRepos[i].CreateAsync(user);
efUserRepos[i].ClearChangeTracking();
cipher.UserId = efUser.Id;
cipher.OrganizationId = null;
var createdCipher = await sut.CreateAsync(cipher);
sut.ClearChangeTracking();
// Precondition: archived
await sut.ArchiveAsync(new[] { createdCipher.Id }, efUser.Id);
sut.ClearChangeTracking();
var unarchiveUtcNow = await sut.UnarchiveAsync(new[] { createdCipher.Id }, efUser.Id);
sut.ClearChangeTracking();
var savedCipher = await sut.GetByIdAsync(createdCipher.Id);
Assert.NotNull(savedCipher);
Assert.Equal(unarchiveUtcNow, savedCipher.RevisionDate);
// Archives should be null or not contain this user (repo clears string when map empty)
if (!string.IsNullOrWhiteSpace(savedCipher.Archives))
{
var archives = CoreHelpers.LoadClassFromJsonData<Dictionary<Guid, DateTime>>(savedCipher.Archives)
?? new Dictionary<Guid, DateTime>();
Assert.False(archives.ContainsKey(efUser.Id));
}
else
{
Assert.Null(savedCipher.Archives);
}
var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id);
Assert.Equal(DateTime.UtcNow.ToShortDateString(), bumpedUser.AccountRevisionDate.ToShortDateString());
}
}
}

View File

@@ -95,6 +95,7 @@ public static class OrganizationTestHelpers
SyncSeats = false,
UseAutomaticUserConfirmation = true,
UsePhishingBlocker = true,
UseDisableSmAdsForUsers = true,
});
}

View File

@@ -675,6 +675,7 @@ public class OrganizationUserRepositoryTests
UseRiskInsights = false,
UseAdminSponsoredFamilies = false,
UsePhishingBlocker = false,
UseDisableSmAdsForUsers = false,
});
var organizationDomain = new OrganizationDomain

View File

@@ -128,7 +128,6 @@ public class DatabaseDataAttribute : DataAttribute
private void AddDapperServices(IServiceCollection services, Database database)
{
services.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
@@ -141,6 +140,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
}
};
services.AddDapperRepositories(SelfHosted);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);
services.AddSingleton(database);
@@ -160,7 +160,6 @@ public class DatabaseDataAttribute : DataAttribute
private void AddEfServices(IServiceCollection services, Database database)
{
services.SetupEntityFramework(database.ConnectionString, database.Type);
services.AddPasswordManagerEFRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
@@ -169,6 +168,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
},
};
services.AddPasswordManagerEFRepositories(SelfHosted);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);

View File

@@ -3,6 +3,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<UserSecretsId>6570f288-5c2c-47ad-8978-f3da255079c2</UserSecretsId>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1207,10 +1207,110 @@ public class CipherRepositoryTests
// Act
await sutRepository.ArchiveAsync(new List<Guid> { cipher.Id }, user.Id);
// Assert
var archivedCipher = await sutRepository.GetByIdAsync(cipher.Id, user.Id);
Assert.NotNull(archivedCipher);
Assert.NotNull(archivedCipher.ArchivedDate);
// Assert per-user view should show an archive date
var archivedCipherForUser = await sutRepository.GetByIdAsync(cipher.Id, user.Id);
Assert.NotNull(archivedCipherForUser);
Assert.NotNull(archivedCipherForUser.ArchivedDate);
}
[DatabaseTheory, DatabaseData]
public async Task ArchiveAsync_IsPerUserForSharedCipher(
ICipherRepository cipherRepository,
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository,
ICollectionCipherRepository collectionCipherRepository)
{
// Arrange: two users in the same org, both with access to the same cipher
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = $"test+{Guid.NewGuid()}@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = $"test+{Guid.NewGuid()}@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var org = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Organization",
BillingEmail = user1.Email,
Plan = "Test",
});
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
UserId = user1.Id,
OrganizationId = org.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.Owner,
});
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
UserId = user2.Id,
OrganizationId = org.Id,
Status = OrganizationUserStatusType.Confirmed,
Type = OrganizationUserType.User,
});
var sharedCollection = await collectionRepository.CreateAsync(new Collection
{
Name = "Shared Collection",
OrganizationId = org.Id,
});
var cipher = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = org.Id,
Data = "",
});
await collectionCipherRepository.UpdateCollectionsForAdminAsync(
cipher.Id,
org.Id,
new List<Guid> { sharedCollection.Id });
// Give both org users access to the shared collection
await collectionRepository.UpdateUsersAsync(sharedCollection.Id, new List<CollectionAccessSelection>
{
new()
{
Id = orgUser1.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true,
},
new()
{
Id = orgUser2.Id,
HidePasswords = false,
ReadOnly = false,
Manage = true,
},
});
// Act: user1 archives the shared cipher
await cipherRepository.ArchiveAsync(new List<Guid> { cipher.Id }, user1.Id);
// Assert: user1 sees it as archived
var cipherForUser1 = await cipherRepository.GetByIdAsync(cipher.Id, user1.Id);
Assert.NotNull(cipherForUser1);
Assert.NotNull(cipherForUser1.ArchivedDate);
// Assert: user2 still sees it as *not* archived
var cipherForUser2 = await cipherRepository.GetByIdAsync(cipher.Id, user2.Id);
Assert.NotNull(cipherForUser2);
Assert.Null(cipherForUser2.ArchivedDate);
}
[DatabaseTheory, DatabaseData]

View File

@@ -189,7 +189,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
/// Registers a new user to the Identity Application Factory based on the RegisterFinishRequestModel
/// </summary>
/// <param name="requestModel">RegisterFinishRequestModel needed to seed data to the test user</param>
/// <param name="marketingEmails">optional parameter that is tracked during the inital steps of registration.</param>
/// <param name="marketingEmails">optional parameter that is tracked during the initial steps of registration.</param>
/// <returns>returns the newly created user</returns>
public async Task<User> RegisterNewIdentityFactoryUserAsync(
RegisterFinishRequestModel requestModel,

View File

@@ -47,7 +47,7 @@ public abstract class WebApplicationFactoryBase<T> : WebApplicationFactory<T>
/// </remarks>
public bool ManagesDatabase { get; set; } = true;
private readonly List<Action<IServiceCollection>> _configureTestServices = new();
protected readonly List<Action<IServiceCollection>> _configureTestServices = new();
private readonly List<Action<IConfigurationBuilder>> _configureAppConfiguration = new();
public void SubstituteService<TService>(Action<TService> mockService)

View File

@@ -2,13 +2,15 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.10" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Identity\Identity.csproj" />
<ProjectReference Include="..\..\util\Migrator\Migrator.csproj" />

View File

@@ -0,0 +1,40 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
namespace Bit.SeederApi.IntegrationTest;
public static class HttpClientExtensions
{
/// <summary>
/// Sends a POST request with JSON content and attaches the x-play-id header.
/// </summary>
/// <typeparam name="TValue">The type of the value to serialize.</typeparam>
/// <param name="client">The HTTP client.</param>
/// <param name="requestUri">The URI the request is sent to.</param>
/// <param name="value">The value to serialize.</param>
/// <param name="playId">The play ID to attach as x-play-id header.</param>
/// <param name="options">Options to control the behavior during serialization.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
/// <returns>The task object representing the asynchronous operation.</returns>
public static Task<HttpResponseMessage> PostAsJsonAsync<TValue>(
this HttpClient client,
[StringSyntax(StringSyntaxAttribute.Uri)] string? requestUri,
TValue value,
string playId,
JsonSerializerOptions? options = null,
CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(client);
if (string.IsNullOrWhiteSpace(playId))
{
throw new ArgumentException("Play ID cannot be null or whitespace.", nameof(playId));
}
var content = JsonContent.Create(value, mediaType: null, options);
content.Headers.Remove("x-play-id");
content.Headers.Add("x-play-id", playId);
return client.PostAsync(requestUri, content, cancellationToken);
}
}

View File

@@ -0,0 +1,75 @@
using System.Net;
using Bit.SeederApi.Models.Request;
using Xunit;
namespace Bit.SeederApi.IntegrationTest;
public class QueryControllerTests : IClassFixture<SeederApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly SeederApiApplicationFactory _factory;
public QueryControllerTests(SeederApiApplicationFactory factory)
{
_factory = factory;
_client = _factory.CreateClient();
}
public Task InitializeAsync()
{
return Task.CompletedTask;
}
public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}
[Fact]
public async Task QueryEndpoint_WithValidQueryAndArguments_ReturnsOk()
{
var testEmail = $"emergency-test-{Guid.NewGuid()}@bitwarden.com";
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "EmergencyAccessInviteQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
});
response.EnsureSuccessStatusCode();
var result = await response.Content.ReadAsStringAsync();
Assert.NotNull(result);
var urls = System.Text.Json.JsonSerializer.Deserialize<List<string>>(result);
Assert.NotNull(urls);
// For a non-existent email, we expect an empty list
Assert.Empty(urls);
}
[Fact]
public async Task QueryEndpoint_WithInvalidQueryName_ReturnsNotFound()
{
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "NonExistentQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = "test@example.com" })
});
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}
[Fact]
public async Task QueryEndpoint_WithMissingRequiredField_ReturnsBadRequest()
{
// EmergencyAccessInviteQuery requires 'email' field
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "EmergencyAccessInviteQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { wrongField = "value" })
});
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
}

View File

@@ -0,0 +1,222 @@
using System.Net;
using Bit.SeederApi.Models.Request;
using Bit.SeederApi.Models.Response;
using Xunit;
namespace Bit.SeederApi.IntegrationTest;
public class SeedControllerTests : IClassFixture<SeederApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly SeederApiApplicationFactory _factory;
public SeedControllerTests(SeederApiApplicationFactory factory)
{
_factory = factory;
_client = _factory.CreateClient();
}
public Task InitializeAsync()
{
return Task.CompletedTask;
}
public async Task DisposeAsync()
{
// Clean up any seeded data after each test
await _client.DeleteAsync("/seed");
_client.Dispose();
}
[Fact]
public async Task SeedEndpoint_WithValidScene_ReturnsOk()
{
var testEmail = $"seed-test-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
response.EnsureSuccessStatusCode();
var result = await response.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(result);
Assert.NotNull(result.MangleMap);
Assert.Null(result.Result);
}
[Fact]
public async Task SeedEndpoint_WithInvalidSceneName_ReturnsNotFound()
{
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "NonExistentScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = "test@example.com" })
});
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}
[Fact]
public async Task SeedEndpoint_WithMissingRequiredField_ReturnsBadRequest()
{
// SingleUserScene requires 'email' field
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { wrongField = "value" })
});
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task DeleteEndpoint_WithValidPlayId_ReturnsOk()
{
var testEmail = $"delete-test-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
var deleteResponse = await _client.DeleteAsync($"/seed/{playId}");
deleteResponse.EnsureSuccessStatusCode();
}
[Fact]
public async Task DeleteEndpoint_WithInvalidPlayId_ReturnsOk()
{
// DestroyRecipe is idempotent - returns null for non-existent play IDs
var nonExistentPlayId = Guid.NewGuid().ToString();
var response = await _client.DeleteAsync($"/seed/{nonExistentPlayId}");
response.EnsureSuccessStatusCode();
var content = await response.Content.ReadAsStringAsync();
Assert.Equal($$"""{"playId":"{{nonExistentPlayId}}"}""", content);
}
[Fact]
public async Task DeleteBatchEndpoint_WithValidPlayIds_ReturnsOk()
{
// Create multiple seeds with different play IDs
var playIds = new List<string>();
for (var i = 0; i < 3; i++)
{
var playId = Guid.NewGuid().ToString();
playIds.Add(playId);
var testEmail = $"batch-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
}
// Delete them in batch
var request = new HttpRequestMessage(HttpMethod.Delete, "/seed/batch")
{
Content = JsonContent.Create(playIds)
};
var deleteResponse = await _client.SendAsync(request);
deleteResponse.EnsureSuccessStatusCode();
var result = await deleteResponse.Content.ReadFromJsonAsync<BatchDeleteResponse>();
Assert.NotNull(result);
Assert.Equal("Batch delete completed successfully", result.Message);
}
[Fact]
public async Task DeleteBatchEndpoint_WithSomeInvalidIds_ReturnsOk()
{
// DestroyRecipe is idempotent - batch delete succeeds even with non-existent IDs
// Create one valid seed with a play ID
var validPlayId = Guid.NewGuid().ToString();
var testEmail = $"batch-partial-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, validPlayId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
// Try to delete with mix of valid and invalid IDs
var playIds = new List<string> { validPlayId, Guid.NewGuid().ToString(), Guid.NewGuid().ToString() };
var request = new HttpRequestMessage(HttpMethod.Delete, "/seed/batch")
{
Content = JsonContent.Create(playIds)
};
var deleteResponse = await _client.SendAsync(request);
deleteResponse.EnsureSuccessStatusCode();
var result = await deleteResponse.Content.ReadFromJsonAsync<BatchDeleteResponse>();
Assert.NotNull(result);
Assert.Equal("Batch delete completed successfully", result.Message);
}
[Fact]
public async Task DeleteAllEndpoint_DeletesAllSeededData()
{
// Create multiple seeds
for (var i = 0; i < 2; i++)
{
var playId = Guid.NewGuid().ToString();
var testEmail = $"deleteall-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
}
// Delete all
var deleteResponse = await _client.DeleteAsync("/seed");
Assert.Equal(HttpStatusCode.NoContent, deleteResponse.StatusCode);
}
[Fact]
public async Task SeedEndpoint_VerifyResponseContainsMangleMapAndResult()
{
var testEmail = $"verify-response-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
response.EnsureSuccessStatusCode();
var jsonString = await response.Content.ReadAsStringAsync();
// Verify the response contains MangleMap and Result fields
Assert.Contains("mangleMap", jsonString, StringComparison.OrdinalIgnoreCase);
Assert.Contains("result", jsonString, StringComparison.OrdinalIgnoreCase);
}
private class BatchDeleteResponse
{
public string? Message { get; set; }
}
}

View File

@@ -0,0 +1,29 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\util\SeederApi\SeederApi.csproj" />
<ProjectReference Include="..\..\util\Seeder\Seeder.csproj" />
<ProjectReference Include="..\IntegrationTestCommon\IntegrationTestCommon.csproj" />
<Content Include="..\..\util\SeederApi\appsettings.*.json">
<Link>%(RecursiveDir)%(Filename)%(Extension)</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,18 @@
using Bit.Core.Services;
using Bit.IntegrationTestCommon;
using Bit.IntegrationTestCommon.Factories;
namespace Bit.SeederApi.IntegrationTest;
public class SeederApiApplicationFactory : WebApplicationFactoryBase<Startup>
{
public SeederApiApplicationFactory()
{
TestDatabase = new SqliteTestDatabase();
_configureTestServices.Add(serviceCollection =>
{
serviceCollection.AddSingleton<IPlayIdService, NeverPlayIdServices>();
serviceCollection.AddHttpContextAccessor();
});
}
}

View File

@@ -0,0 +1 @@
[assembly: CaptureTrace]

View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Using Include="Xunit" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.11.0" />
<PackageReference Include="xunit.v3" Version="3.0.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.1.4" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.10" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\util\Server\Server.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,45 @@
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.AspNetCore.TestHost;
namespace Bit.Server.IntegrationTest;
public class Server : WebApplicationFactory<Program>
{
public string? ContentRoot { get; set; }
public string? WebRoot { get; set; }
public bool ServeUnknown { get; set; }
public bool? WebVault { get; set; }
public string? AppIdLocation { get; set; }
protected override IWebHostBuilder? CreateWebHostBuilder()
{
var args = new List<string>
{
"/contentRoot",
ContentRoot ?? "",
"/webRoot",
WebRoot ?? "",
"/serveUnknown",
ServeUnknown.ToString().ToLowerInvariant(),
};
if (WebVault.HasValue)
{
args.Add("/webVault");
args.Add(WebVault.Value.ToString().ToLowerInvariant());
}
if (!string.IsNullOrEmpty(AppIdLocation))
{
args.Add("/appIdLocation");
args.Add(AppIdLocation);
}
var builder = WebHostBuilderFactory.CreateFromTypesAssemblyEntryPoint<Program>([.. args])
?? throw new InvalidProgramException("Could not create builder from assembly.");
builder.UseSetting("TEST_CONTENTROOT_SERVER", ContentRoot);
return builder;
}
}

View File

@@ -0,0 +1,102 @@
using System.Net;
using System.Runtime.CompilerServices;
namespace Bit.Server.IntegrationTest;
public class ServerTests
{
[Fact]
public async Task AttachmentsStyleUse()
{
using var tempDir = new TempDir();
await tempDir.WriteAsync("my-file.txt", "Hello!");
using var server = new Server
{
ContentRoot = tempDir.Info.FullName,
WebRoot = ".",
ServeUnknown = true,
};
var client = server.CreateClient();
var response = await client.GetAsync("/my-file.txt", TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("Hello!", await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken));
}
[Fact]
public async Task WebVaultStyleUse()
{
using var tempDir = new TempDir();
await tempDir.WriteAsync("index.html", "<html></html>");
await tempDir.WriteAsync(Path.Join("app", "file.js"), "AppStuff");
await tempDir.WriteAsync(Path.Join("locales", "file.json"), "LocalesStuff");
await tempDir.WriteAsync(Path.Join("fonts", "file.ttf"), "FontsStuff");
await tempDir.WriteAsync(Path.Join("connectors", "file.js"), "ConnectorsStuff");
await tempDir.WriteAsync(Path.Join("scripts", "file.js"), "ScriptsStuff");
await tempDir.WriteAsync(Path.Join("images", "file.avif"), "ImagesStuff");
await tempDir.WriteAsync(Path.Join("test", "file.json"), "{}");
using var server = new Server
{
ContentRoot = tempDir.Info.FullName,
WebRoot = ".",
ServeUnknown = false,
WebVault = true,
AppIdLocation = Path.Join(tempDir.Info.FullName, "test", "file.json"),
};
var client = server.CreateClient();
// Going to root should return the default file
var response = await client.GetAsync("", TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("<html></html>", await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken));
// No caching on the default document
Assert.Null(response.Headers.CacheControl?.MaxAge);
await ExpectMaxAgeAsync("app/file.js", TimeSpan.FromDays(14));
await ExpectMaxAgeAsync("locales/file.json", TimeSpan.FromDays(14));
await ExpectMaxAgeAsync("fonts/file.ttf", TimeSpan.FromDays(14));
await ExpectMaxAgeAsync("connectors/file.js", TimeSpan.FromDays(14));
await ExpectMaxAgeAsync("scripts/file.js", TimeSpan.FromDays(14));
await ExpectMaxAgeAsync("images/file.avif", TimeSpan.FromDays(7));
response = await client.GetAsync("app-id.json", TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("application/json", response.Content.Headers.ContentType?.MediaType);
async Task ExpectMaxAgeAsync(string path, TimeSpan maxAge)
{
response = await client.GetAsync(path);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.NotNull(response.Headers.CacheControl);
Assert.Equal(maxAge, response.Headers.CacheControl.MaxAge);
}
}
private class TempDir([CallerMemberName] string test = null!) : IDisposable
{
public DirectoryInfo Info { get; } = Directory.CreateTempSubdirectory(test);
public void Dispose()
{
Info.Delete(recursive: true);
}
public async Task WriteAsync(string fileName, string content)
{
var fullPath = Path.Join(Info.FullName, fileName);
var directory = Path.GetDirectoryName(fullPath);
if (directory != null)
{
Directory.CreateDirectory(directory);
}
await File.WriteAllTextAsync(fullPath, content, TestContext.Current.CancellationToken);
}
}
}

View File

@@ -0,0 +1,102 @@
using Bit.Core.Services;
using Bit.SharedWeb.Utilities;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Hosting;
using NSubstitute;
namespace SharedWeb.Test;
public class PlayIdMiddlewareTests
{
private readonly PlayIdService _playIdService;
private readonly RequestDelegate _next;
private readonly PlayIdMiddleware _middleware;
public PlayIdMiddlewareTests()
{
var hostEnvironment = Substitute.For<IHostEnvironment>();
hostEnvironment.EnvironmentName.Returns(Environments.Development);
_playIdService = new PlayIdService(hostEnvironment);
_next = Substitute.For<RequestDelegate>();
_middleware = new PlayIdMiddleware(_next);
}
[Fact]
public async Task Invoke_WithValidPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play-id";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play-id", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithoutPlayIdHeader_CallsNext()
{
var context = new DefaultHttpContext();
await _middleware.Invoke(context, _playIdService);
Assert.Null(_playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Theory]
[InlineData("")]
[InlineData(" ")]
[InlineData("\t")]
public async Task Invoke_WithEmptyOrWhitespacePlayId_Returns400(string playId)
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
context.Request.Headers["x-play-id"] = playId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdExceedingMaxLength_Returns400()
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
var longPlayId = new string('a', 257); // Exceeds 256 character limit
context.Request.Headers["x-play-id"] = longPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdAtMaxLength_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
var maxLengthPlayId = new string('a', 256); // Exactly 256 characters
context.Request.Headers["x-play-id"] = maxLengthPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(maxLengthPlayId, _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithSpecialCharactersInPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play_id.123";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play_id.123", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
}