1
0
mirror of https://github.com/bitwarden/server synced 2026-03-01 10:51:26 +00:00

Merge branch 'main' into auth/pm-29584/create-email-for-emergency-access-removal

This commit is contained in:
enmande
2026-01-14 09:46:44 -05:00
185 changed files with 19973 additions and 2304 deletions

View File

@@ -4,7 +4,6 @@ using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Api.IntegrationTest.Factories;
using Bit.Api.IntegrationTest.Helpers;
using Bit.Api.Models.Response;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
@@ -14,8 +13,6 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Xunit;
namespace Bit.Api.IntegrationTest.AdminConsole.Controllers;
@@ -32,12 +29,6 @@ public class OrganizationUserControllerBulkRevokeTests : IClassFixture<ApiApplic
public OrganizationUserControllerBulkRevokeTests(ApiApplicationFactory apiFactory)
{
_factory = apiFactory;
_factory.SubstituteService<IFeatureService>(featureService =>
{
featureService
.IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2)
.Returns(true);
});
_client = _factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client);
}

View File

@@ -1,6 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -3,6 +3,8 @@ using Bit.Api.Billing.Models.Requests.Storage;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Licenses.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Subscriptions.Commands;
using Bit.Core.Billing.Subscriptions.Queries;
using Bit.Core.Entities;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
@@ -29,9 +31,11 @@ public class AccountBillingVNextControllerTests
_sut = new AccountBillingVNextController(
Substitute.For<Core.Billing.Payment.Commands.ICreateBitPayInvoiceForCreditCommand>(),
Substitute.For<Core.Billing.Premium.Commands.ICreatePremiumCloudHostedSubscriptionCommand>(),
Substitute.For<IGetBitwardenSubscriptionQuery>(),
Substitute.For<Core.Billing.Payment.Queries.IGetCreditQuery>(),
Substitute.For<Core.Billing.Payment.Queries.IGetPaymentMethodQuery>(),
_getUserLicenseQuery,
Substitute.For<IReinstateSubscriptionCommand>(),
Substitute.For<Core.Billing.Payment.Commands.IUpdatePaymentMethodCommand>(),
_updatePremiumStorageCommand,
_upgradePremiumToOrganizationCommand);
@@ -63,7 +67,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -83,7 +87,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -103,7 +107,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -123,7 +127,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -143,7 +147,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -163,7 +167,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BadRequest(errorMessage));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var badRequestResult = Assert.IsAssignableFrom<IResult>(result);
@@ -182,7 +186,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -201,7 +205,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -220,7 +224,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);
@@ -239,7 +243,7 @@ public class AccountBillingVNextControllerTests
.Returns(new BillingCommandResult<None>(new None()));
// Act
var result = await _sut.UpdateStorageAsync(user, request);
var result = await _sut.UpdateSubscriptionStorageAsync(user, request);
// Assert
var okResult = Assert.IsAssignableFrom<IResult>(result);

View File

@@ -26,6 +26,7 @@ public class SutProvider<TSut> : ISutProvider
public TSut Sut { get; private set; }
public Type SutType => typeof(TSut);
public IFixture Fixture => _fixture;
public SutProvider() : this(new Fixture()) { }
@@ -65,6 +66,19 @@ public class SutProvider<TSut> : ISutProvider
return this;
}
/// <summary>
/// Creates and registers a dependency to be injected when the sut is created.
/// </summary>
/// <typeparam name="TDep">The Dependency type to create</typeparam>
/// <param name="parameterName">The (optional) parameter name to register the dependency under</param>
/// <returns>The created dependency value</returns>
public TDep CreateDependency<TDep>(string parameterName = "")
{
var dependency = _fixture.Create<TDep>();
SetDependency(dependency, parameterName);
return dependency;
}
/// <summary>
/// Gets a dependency of the sut. Can only be called after the dependency has been set, either explicitly with
/// <see cref="SetDependency{T}"/> or automatically with <see cref="Create"/>.

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<RootNamespace>Bit.Test.Common</RootNamespace>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,7 +1,10 @@
using System.Text.Json;
using System.Reflection;
using System.Text.Json;
using System.Text.RegularExpressions;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Billing.Organizations.Models;
using Bit.Test.Common.Helpers;
using Xunit;
@@ -115,4 +118,105 @@ public class OrganizationTests
Assert.True(organization.UseDisableSmAdsForUsers);
}
[Fact]
public void UpdateFromLicense_AppliesAllLicenseProperties()
{
// This test ensures that when a new property is added to OrganizationLicense,
// it is also applied to the Organization in UpdateFromLicense().
// This is the fourth step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Get all public properties from OrganizationLicense
var licenseProperties = typeof(OrganizationLicense)
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Select(p => p.Name)
.ToHashSet();
// 2. Define properties that don't need to be applied to Organization
var excludedProperties = new HashSet<string>
{
// Internal/computed properties
"SignatureBytes", // Computed from Signature property
"ValidLicenseVersion", // Internal property, not serialized
"CurrentLicenseFileVersion", // Constant field, not an instance property
"Hash", // Signature-related, not applied to org
"Signature", // Signature-related, not applied to org
"Token", // The JWT itself, not applied to org
"Version", // License version, not stored on org
// Properties intentionally excluded from UpdateFromLicense
"Id", // Self-hosted org has its own unique Guid
"MaxStorageGb", // Not enforced for self-hosted (per comment in UpdateFromLicense)
// Properties not stored on Organization model
"LicenseType", // Not a property on Organization
"InstallationId", // Not a property on Organization
"Issued", // Not a property on Organization
"Refresh", // Not a property on Organization
"ExpirationWithoutGracePeriod", // Not a property on Organization
"Trial", // Not a property on Organization
"Expires", // Mapped to ExpirationDate on Organization (different name)
// Deprecated properties not applied
"LimitCollectionCreationDeletion", // Deprecated, not applied
"AllowAdminAccessToAllCollectionItems", // Deprecated, not applied
};
// 3. Get properties that should be applied
var propertiesThatShouldBeApplied = licenseProperties
.Except(excludedProperties)
.ToHashSet();
// 4. Read Organization.UpdateFromLicense source code
var organizationSourcePath = Path.Combine(
Directory.GetCurrentDirectory(),
"..", "..", "..", "..", "..", "src", "Core", "AdminConsole", "Entities", "Organization.cs");
var sourceCode = File.ReadAllText(organizationSourcePath);
// 5. Find all property assignments in UpdateFromLicense method
// Pattern matches: PropertyName = license.PropertyName
// This regex looks for assignments like "Name = license.Name" or "ExpirationDate = license.Expires"
var assignmentPattern = @"(\w+)\s*=\s*license\.(\w+)";
var matches = Regex.Matches(sourceCode, assignmentPattern);
var appliedProperties = new HashSet<string>();
foreach (Match match in matches)
{
// Get the license property name (right side of assignment)
var licensePropertyName = match.Groups[2].Value;
appliedProperties.Add(licensePropertyName);
}
// Special case: Expires is mapped to ExpirationDate
if (appliedProperties.Contains("Expires"))
{
appliedProperties.Add("Expires"); // Already added, but being explicit
}
// 6. Find missing applications
var missingApplications = propertiesThatShouldBeApplied
.Except(appliedProperties)
.OrderBy(p => p)
.ToList();
// 7. Build error message with guidance
var errorMessage = "";
if (missingApplications.Any())
{
errorMessage = $"The following OrganizationLicense properties are NOT applied to Organization in UpdateFromLicense():\n";
errorMessage += string.Join("\n", missingApplications.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following lines to Organization.UpdateFromLicense():\n";
foreach (var prop in missingApplications)
{
errorMessage += $" {prop} = license.{prop};\n";
}
errorMessage += "\nNote: If the property maps to a different name on Organization (like Expires → ExpirationDate), adjust accordingly.";
}
// 8. Assert - if this fails, the error message guides the developer to add the application
Assert.True(
!missingApplications.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -0,0 +1,68 @@
using System.Reflection;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Organizations.Models;
using Xunit;
namespace Bit.Core.Test.Billing.Licenses;
public class LicenseConstantsTests
{
[Fact]
public void OrganizationLicenseConstants_HasConstantForEveryLicenseProperty()
{
// This test ensures that when a new property is added to OrganizationLicense,
// a corresponding constant is added to OrganizationLicenseConstants.
// This is the first step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Get all public properties from OrganizationLicense
var licenseProperties = typeof(OrganizationLicense)
.GetProperties(BindingFlags.Public | BindingFlags.Instance)
.Select(p => p.Name)
.ToHashSet();
// 2. Get all constants from OrganizationLicenseConstants
var constants = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToHashSet();
// 3. Define properties that don't need constants (internal/computed/non-claims properties)
var excludedProperties = new HashSet<string>
{
"SignatureBytes", // Computed from Signature property
"ValidLicenseVersion", // Internal property, not serialized
"CurrentLicenseFileVersion", // Constant field, not an instance property
"Hash", // Signature-related, not in claims system
"Signature", // Signature-related, not in claims system
"Token", // The JWT itself, not a claim within the token
"Version" // Not in claims system (only in deprecated property-based licenses)
};
// 4. Find license properties without corresponding constants
var propertiesWithoutConstants = licenseProperties
.Except(constants)
.Except(excludedProperties)
.OrderBy(p => p)
.ToList();
// 5. Build error message with guidance
var errorMessage = "";
if (propertiesWithoutConstants.Any())
{
errorMessage = $"The following OrganizationLicense properties don't have constants in OrganizationLicenseConstants:\n";
errorMessage += string.Join("\n", propertiesWithoutConstants.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following constants to OrganizationLicenseConstants:\n";
foreach (var prop in propertiesWithoutConstants)
{
errorMessage += $" public const string {prop} = nameof({prop});\n";
}
}
// 6. Assert - if this fails, the error message guides the developer to add the constant
Assert.True(
!propertiesWithoutConstants.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -0,0 +1,92 @@
using System.Reflection;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Models;
using Bit.Core.Billing.Licenses.Services.Implementations;
using Bit.Core.Models.Business;
using Bit.Test.Common.AutoFixture.Attributes;
using Xunit;
namespace Bit.Core.Test.Billing.Licenses.Services.Implementations;
public class OrganizationLicenseClaimsFactoryTests
{
[Theory, BitAutoData]
public async Task GenerateClaims_CreatesClaimsForAllConstants(Organization organization)
{
// This test ensures that when a constant is added to OrganizationLicenseConstants,
// it is also added to the OrganizationLicenseClaimsFactory to generate claims.
// This is the second step in the license synchronization pipeline:
// Property → Constant → Claim → Extraction → Application
// 1. Populate all nullable properties to ensure claims can be generated
// The factory only adds claims for properties that have values
organization.Name = "Test Organization";
organization.BillingEmail = "billing@test.com";
organization.BusinessName = "Test Business";
organization.Plan = "Enterprise";
organization.LicenseKey = "test-license-key";
organization.Seats = 100;
organization.MaxCollections = 50;
organization.MaxStorageGb = 10;
organization.SmSeats = 25;
organization.SmServiceAccounts = 10;
organization.ExpirationDate = DateTime.UtcNow.AddYears(1); // Ensure org is not expired
// Create a LicenseContext with a minimal SubscriptionInfo to trigger conditional claims
// ExpirationWithoutGracePeriod is only generated for active, non-trial, annual subscriptions
var licenseContext = new LicenseContext
{
InstallationId = Guid.NewGuid(),
SubscriptionInfo = new SubscriptionInfo
{
Subscription = new SubscriptionInfo.BillingSubscription(null!)
{
TrialEndDate = DateTime.UtcNow.AddDays(-30), // Trial ended in the past
PeriodStartDate = DateTime.UtcNow,
PeriodEndDate = DateTime.UtcNow.AddDays(365), // Annual subscription (>180 days)
Status = "active"
}
}
};
// 2. Generate claims
var factory = new OrganizationLicenseClaimsFactory();
var claims = await factory.GenerateClaims(organization, licenseContext);
// 3. Get all constants from OrganizationLicenseConstants
var allConstants = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToHashSet();
// 4. Get claim types from generated claims
var generatedClaimTypes = claims.Select(c => c.Type).ToHashSet();
// 5. Find constants that don't have corresponding claims
var constantsWithoutClaims = allConstants
.Except(generatedClaimTypes)
.OrderBy(c => c)
.ToList();
// 6. Build error message with guidance
var errorMessage = "";
if (constantsWithoutClaims.Any())
{
errorMessage = $"The following constants in OrganizationLicenseConstants are NOT generated as claims in OrganizationLicenseClaimsFactory:\n";
errorMessage += string.Join("\n", constantsWithoutClaims.Select(c => $" - {c}"));
errorMessage += "\n\nPlease add the following claims to OrganizationLicenseClaimsFactory.GenerateClaims():\n";
foreach (var constant in constantsWithoutClaims)
{
errorMessage += $" new(nameof(OrganizationLicenseConstants.{constant}), entity.{constant}.ToString()),\n";
}
errorMessage += "\nNote: If the property is nullable, you may need to add it conditionally.";
}
// 7. Assert - if this fails, the error message guides the developer to add claim generation
Assert.True(
!constantsWithoutClaims.Any(),
$"\n{errorMessage}");
}
}

View File

@@ -1,9 +1,14 @@
using System.Security.Claims;
using System.Reflection;
using System.Security.Claims;
using System.Text.RegularExpressions;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Services;
using Bit.Core.Settings;
@@ -99,6 +104,320 @@ public class UpdateOrganizationLicenseCommandTests
}
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_WithClaimsPrincipal_ExtractsAllPropertiesFromClaims(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup license for CanUse validation
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-1);
license.Expires = DateTime.Now.AddDays(1);
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.InstallationId = globalSettings.Installation.Id;
license.LicenseType = LicenseType.Organization;
license.Token = "test-token"; // Indicates this is a claims-based license
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
// Create a ClaimsPrincipal with all organization license claims
var claims = new List<Claim>
{
new(OrganizationLicenseConstants.LicenseType, ((int)LicenseType.Organization).ToString()),
new(OrganizationLicenseConstants.InstallationId, globalSettings.Installation.Id.ToString()),
new(OrganizationLicenseConstants.Name, "Test Organization"),
new(OrganizationLicenseConstants.BillingEmail, "billing@test.com"),
new(OrganizationLicenseConstants.BusinessName, "Test Business"),
new(OrganizationLicenseConstants.PlanType, ((int)PlanType.EnterpriseAnnually).ToString()),
new(OrganizationLicenseConstants.Seats, "100"),
new(OrganizationLicenseConstants.MaxCollections, "50"),
new(OrganizationLicenseConstants.UsePolicies, "true"),
new(OrganizationLicenseConstants.UseSso, "true"),
new(OrganizationLicenseConstants.UseKeyConnector, "true"),
new(OrganizationLicenseConstants.UseScim, "true"),
new(OrganizationLicenseConstants.UseGroups, "true"),
new(OrganizationLicenseConstants.UseDirectory, "true"),
new(OrganizationLicenseConstants.UseEvents, "true"),
new(OrganizationLicenseConstants.UseTotp, "true"),
new(OrganizationLicenseConstants.Use2fa, "true"),
new(OrganizationLicenseConstants.UseApi, "true"),
new(OrganizationLicenseConstants.UseResetPassword, "true"),
new(OrganizationLicenseConstants.Plan, "Enterprise"),
new(OrganizationLicenseConstants.SelfHost, "true"),
new(OrganizationLicenseConstants.UsersGetPremium, "true"),
new(OrganizationLicenseConstants.UseCustomPermissions, "true"),
new(OrganizationLicenseConstants.Enabled, "true"),
new(OrganizationLicenseConstants.Expires, DateTime.Now.AddDays(1).ToString("O")),
new(OrganizationLicenseConstants.LicenseKey, "test-license-key"),
new(OrganizationLicenseConstants.UsePasswordManager, "true"),
new(OrganizationLicenseConstants.UseSecretsManager, "true"),
new(OrganizationLicenseConstants.SmSeats, "25"),
new(OrganizationLicenseConstants.SmServiceAccounts, "10"),
new(OrganizationLicenseConstants.UseRiskInsights, "true"),
new(OrganizationLicenseConstants.UseOrganizationDomains, "true"),
new(OrganizationLicenseConstants.UseAdminSponsoredFamilies, "true"),
new(OrganizationLicenseConstants.UseAutomaticUserConfirmation, "true"),
new(OrganizationLicenseConstants.UseDisableSmAdsForUsers, "true"),
new(OrganizationLicenseConstants.UsePhishingBlocker, "true"),
new(OrganizationLicenseConstants.MaxStorageGb, "5"),
new(OrganizationLicenseConstants.Issued, DateTime.Now.AddDays(-1).ToString("O")),
new(OrganizationLicenseConstants.Refresh, DateTime.Now.AddMonths(1).ToString("O")),
new(OrganizationLicenseConstants.ExpirationWithoutGracePeriod, DateTime.Now.AddMonths(12).ToString("O")),
new(OrganizationLicenseConstants.Trial, "false"),
new(OrganizationLicenseConstants.LimitCollectionCreationDeletion, "true"),
new(OrganizationLicenseConstants.AllowAdminAccessToAllCollectionItems, "true")
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(claims));
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns(claimsPrincipal);
// Setup selfHostedOrg for CanUseLicense validation
selfHostedOrg.OccupiedSeatCount = 50; // Less than the 100 seats in the license
selfHostedOrg.CollectionCount = 10; // Less than the 50 max collections in the license
selfHostedOrg.GroupCount = 1;
selfHostedOrg.UseGroups = true;
selfHostedOrg.UsePolicies = true;
selfHostedOrg.UseSso = true;
selfHostedOrg.UseKeyConnector = true;
selfHostedOrg.UseScim = true;
selfHostedOrg.UseCustomPermissions = true;
selfHostedOrg.UseResetPassword = true;
try
{
await sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null);
// Assertion: license file should be written to disk
var filePath = Path.Combine(LicenseDirectory, "organization", $"{selfHostedOrg.Id}.json");
await using var fs = File.OpenRead(filePath);
var licenseFromFile = await JsonSerializer.DeserializeAsync<OrganizationLicense>(fs);
AssertHelper.AssertPropertyEqual(license, licenseFromFile, "SignatureBytes");
// Assertion: organization should be updated with ALL properties extracted from claims
await sutProvider.GetDependency<IOrganizationService>()
.Received(1)
.ReplaceAndUpdateCacheAsync(Arg.Is<Organization>(org =>
org.Name == "Test Organization" &&
org.BillingEmail == "billing@test.com" &&
org.BusinessName == "Test Business" &&
org.PlanType == PlanType.EnterpriseAnnually &&
org.Seats == 100 &&
org.MaxCollections == 50 &&
org.UsePolicies == true &&
org.UseSso == true &&
org.UseKeyConnector == true &&
org.UseScim == true &&
org.UseGroups == true &&
org.UseDirectory == true &&
org.UseEvents == true &&
org.UseTotp == true &&
org.Use2fa == true &&
org.UseApi == true &&
org.UseResetPassword == true &&
org.Plan == "Enterprise" &&
org.SelfHost == true &&
org.UsersGetPremium == true &&
org.UseCustomPermissions == true &&
org.Enabled == true &&
org.LicenseKey == "test-license-key" &&
org.UsePasswordManager == true &&
org.UseSecretsManager == true &&
org.SmSeats == 25 &&
org.SmServiceAccounts == 10 &&
org.UseRiskInsights == true &&
org.UseOrganizationDomains == true &&
org.UseAdminSponsoredFamilies == true &&
org.UseAutomaticUserConfirmation == true &&
org.UseDisableSmAdsForUsers == true &&
org.UsePhishingBlocker == true));
}
finally
{
// Clean up temporary directory
if (Directory.Exists(OrganizationLicenseDirectory.Value))
{
Directory.Delete(OrganizationLicenseDirectory.Value, true);
}
}
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_WrongInstallationIdInClaims_ThrowsBadRequestException(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup license for CanUse validation
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-1);
license.Expires = DateTime.Now.AddDays(1);
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.LicenseType = LicenseType.Organization;
license.Token = "test-token"; // Indicates this is a claims-based license
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
// Create a ClaimsPrincipal with WRONG installation ID
var wrongInstallationId = Guid.NewGuid(); // Different from globalSettings.Installation.Id
var claims = new List<Claim>
{
new(OrganizationLicenseConstants.LicenseType, ((int)LicenseType.Organization).ToString()),
new(OrganizationLicenseConstants.InstallationId, wrongInstallationId.ToString()),
new(OrganizationLicenseConstants.Enabled, "true"),
new(OrganizationLicenseConstants.SelfHost, "true")
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(claims));
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns(claimsPrincipal);
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null));
Assert.Contains("The installation ID does not match the current installation.", exception.Message);
// Verify organization was NOT saved
await sutProvider.GetDependency<IOrganizationService>()
.DidNotReceive()
.ReplaceAndUpdateCacheAsync(Arg.Any<Organization>());
}
[Theory, BitAutoData]
public async Task UpdateLicenseAsync_ExpiredLicenseWithoutClaims_ThrowsBadRequestException(
SelfHostedOrganizationDetails selfHostedOrg,
OrganizationLicense license,
SutProvider<UpdateOrganizationLicenseCommand> sutProvider)
{
var globalSettings = sutProvider.GetDependency<IGlobalSettings>();
globalSettings.LicenseDirectory = LicenseDirectory;
globalSettings.SelfHosted = true;
// Setup legacy license (no Token, no claims)
license.Token = null; // Legacy license
license.Enabled = true;
license.Issued = DateTime.Now.AddDays(-2);
license.Expires = DateTime.Now.AddDays(-1); // Expired yesterday
license.Version = OrganizationLicense.CurrentLicenseFileVersion;
license.InstallationId = globalSettings.Installation.Id;
license.LicenseType = LicenseType.Organization;
license.SelfHost = true;
sutProvider.GetDependency<ILicensingService>().VerifyLicense(license).Returns(true);
sutProvider.GetDependency<ILicensingService>()
.GetClaimsPrincipalFromLicense(license)
.Returns((ClaimsPrincipal)null); // No claims for legacy license
// Passing values for SelfHostedOrganizationDetails.CanUseLicense
license.Seats = null;
license.MaxCollections = null;
license.UseGroups = true;
license.UsePolicies = true;
license.UseSso = true;
license.UseKeyConnector = true;
license.UseScim = true;
license.UseCustomPermissions = true;
license.UseResetPassword = true;
// Act & Assert
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateLicenseAsync(selfHostedOrg, license, null));
Assert.Contains("The license has expired.", exception.Message);
// Verify organization was NOT saved
await sutProvider.GetDependency<IOrganizationService>()
.DidNotReceive()
.ReplaceAndUpdateCacheAsync(Arg.Any<Organization>());
}
[Fact]
public async Task UpdateLicenseAsync_ExtractsAllClaimsBasedProperties_WhenClaimsPrincipalProvided()
{
// This test ensures that when new properties are added to OrganizationLicense,
// they are automatically extracted from JWT claims in UpdateOrganizationLicenseCommand.
// If a new constant is added to OrganizationLicenseConstants but not extracted,
// this test will fail with a clear message showing which properties are missing.
// 1. Get all OrganizationLicenseConstants
var constantFields = typeof(OrganizationLicenseConstants)
.GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.GetField)
.Where(f => f.IsLiteral && !f.IsInitOnly)
.Select(f => f.GetValue(null) as string)
.ToList();
// 2. Define properties that should be excluded (not claims-based or intentionally not extracted)
var excludedProperties = new HashSet<string>
{
"Version", // Not in claims system (only in deprecated property-based licenses)
"Hash", // Signature-related, not extracted from claims
"Signature", // Signature-related, not extracted from claims
"SignatureBytes", // Computed from Signature, not a claim
"Token", // The JWT itself, not extracted from claims
"Id" // Cloud org ID from license, not used - self-hosted org has its own separate ID
};
// 3. Get properties that should be extracted from claims
var propertiesThatShouldBeExtracted = constantFields
.Where(c => !excludedProperties.Contains(c))
.ToHashSet();
// 4. Read UpdateOrganizationLicenseCommand source code
var commandSourcePath = Path.Combine(
Directory.GetCurrentDirectory(),
"..", "..", "..", "..", "..",
"src", "Core", "Billing", "Organizations", "Commands", "UpdateOrganizationLicenseCommand.cs");
var sourceCode = await File.ReadAllTextAsync(commandSourcePath);
// 5. Find all GetValue calls that extract properties from claims
// Pattern matches: license.PropertyName = claimsPrincipal.GetValue<Type>(OrganizationLicenseConstants.PropertyName)
var extractedProperties = new HashSet<string>();
var getValuePattern = @"claimsPrincipal\.GetValue<[^>]+>\(OrganizationLicenseConstants\.(\w+)\)";
var matches = Regex.Matches(sourceCode, getValuePattern);
foreach (Match match in matches)
{
extractedProperties.Add(match.Groups[1].Value);
}
// 6. Find missing extractions
var missingExtractions = propertiesThatShouldBeExtracted
.Except(extractedProperties)
.OrderBy(p => p)
.ToList();
// 7. Build error message with guidance if there are missing extractions
var errorMessage = "";
if (missingExtractions.Any())
{
errorMessage = $"The following constants in OrganizationLicenseConstants are NOT extracted from claims in UpdateOrganizationLicenseCommand:\n";
errorMessage += string.Join("\n", missingExtractions.Select(p => $" - {p}"));
errorMessage += "\n\nPlease add the following lines to UpdateOrganizationLicenseCommand.cs in the 'if (claimsPrincipal != null)' block:\n";
foreach (var prop in missingExtractions)
{
errorMessage += $" license.{prop} = claimsPrincipal.GetValue<TYPE>(OrganizationLicenseConstants.{prop});\n";
}
}
// 8. Assert - if this fails, the error message guides the developer to add the extraction
// Note: We don't check for "extra extractions" because that would be a compile error
// (can't reference OrganizationLicenseConstants.Foo if Foo doesn't exist)
Assert.True(
!missingExtractions.Any(),
$"\n{errorMessage}");
}
// Wrapper to compare 2 objects that are different types
private bool AssertPropertyEqual(OrganizationLicense expected, Organization actual, params string[] excludedPropertyStrings)
{

View File

@@ -8,6 +8,7 @@ using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
@@ -29,6 +30,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands;
public class CreatePremiumCloudHostedSubscriptionCommandTests
{
private readonly IBraintreeGateway _braintreeGateway = Substitute.For<IBraintreeGateway>();
private readonly IBraintreeService _braintreeService = Substitute.For<IBraintreeService>();
private readonly IGlobalSettings _globalSettings = Substitute.For<IGlobalSettings>();
private readonly ISetupIntentCache _setupIntentCache = Substitute.For<ISetupIntentCache>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
@@ -59,6 +61,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
_command = new CreatePremiumCloudHostedSubscriptionCommand(
_braintreeGateway,
_braintreeService,
_globalSettings,
_setupIntentCache,
_stripeAdapter,
@@ -235,11 +238,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.LatestInvoiceId = "in_123";
var mockInvoice = Substitute.For<Invoice>();
@@ -258,6 +265,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any<SubscriptionCreateOptions>());
await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts => opts.AutoAdvance == false));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
@@ -456,11 +466,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "incomplete";
mockSubscription.LatestInvoiceId = "in_123";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
@@ -487,6 +501,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
Assert.True(user.Premium);
Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts => opts.AutoAdvance == false));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
}
[Theory, BitAutoData]
@@ -559,11 +576,15 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
mockCustomer.Metadata = new Dictionary<string, string>
{
[Core.Billing.Utilities.BraintreeCustomerIdKey] = "bt_customer_123"
};
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active"; // PayPal + active doesn't match pattern
mockSubscription.LatestInvoiceId = "in_123";
mockSubscription.Items = new StripeList<SubscriptionItem>
{
Data =
@@ -590,6 +611,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests
Assert.True(result.IsT0);
Assert.False(user.Premium);
Assert.Null(user.PremiumExpirationDate);
await _stripeAdapter.Received(1).UpdateInvoiceAsync(mockSubscription.LatestInvoiceId,
Arg.Is<InvoiceUpdateOptions>(opts => opts.AutoAdvance == false));
await _braintreeService.Received(1).PayInvoice(Arg.Any<SubscriberId>(), mockInvoice);
}
[Theory, BitAutoData]

View File

@@ -18,13 +18,11 @@ public class UpdatePremiumStorageCommandTests
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly PremiumPlan _premiumPlan;
private readonly UpdatePremiumStorageCommand _command;
public UpdatePremiumStorageCommandTests()
{
// Setup default premium plan with standard pricing
_premiumPlan = new PremiumPlan
var premiumPlan = new PremiumPlan
{
Name = "Premium",
Available = true,
@@ -32,7 +30,7 @@ public class UpdatePremiumStorageCommandTests
Seat = new PremiumPurchasable { Price = 10M, StripePriceId = "price_premium", Provided = 1 },
Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "price_storage", Provided = 1 }
};
_pricingClient.ListPremiumPlans().Returns(new List<PremiumPlan> { _premiumPlan });
_pricingClient.ListPremiumPlans().Returns([premiumPlan]);
_command = new UpdatePremiumStorageCommand(
_stripeAdapter,
@@ -43,18 +41,19 @@ public class UpdatePremiumStorageCommandTests
private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null)
{
var items = new List<SubscriptionItem>();
// Always add the seat item
items.Add(new SubscriptionItem
var items = new List<SubscriptionItem>
{
Id = "si_seat",
Price = new Price { Id = "price_premium" },
Quantity = 1
});
// Always add the seat item
new()
{
Id = "si_seat",
Price = new Price { Id = "price_premium" },
Quantity = 1
}
};
// Add storage item if quantity is provided
if (storageQuantity.HasValue && storageQuantity.Value > 0)
if (storageQuantity is > 0)
{
items.Add(new SubscriptionItem
{
@@ -142,7 +141,7 @@ public class UpdatePremiumStorageCommandTests
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("No access to storage.", badRequest.Response);
Assert.Equal("User has no access to storage.", badRequest.Response);
}
[Theory, BitAutoData]
@@ -216,7 +215,7 @@ public class UpdatePremiumStorageCommandTests
opts.Items.Count == 1 &&
opts.Items[0].Id == "si_storage" &&
opts.Items[0].Quantity == 9 &&
opts.ProrationBehavior == "create_prorations"));
opts.ProrationBehavior == "always_invoice"));
// Verify user was saved
await _userService.Received(1).SaveUserAsync(Arg.Is<User>(u =>
@@ -233,7 +232,7 @@ public class UpdatePremiumStorageCommandTests
user.Storage = 500L * 1024 * 1024;
user.GatewaySubscriptionId = "sub_123";
var subscription = CreateMockSubscription("sub_123", null);
var subscription = CreateMockSubscription("sub_123");
_stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription);
// Act

View File

@@ -0,0 +1,607 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Subscriptions.Models;
using Bit.Core.Billing.Subscriptions.Queries;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Subscriptions.Queries;
using static StripeConstants;
public class GetBitwardenSubscriptionQueryTests
{
private readonly ILogger<GetBitwardenSubscriptionQuery> _logger = Substitute.For<ILogger<GetBitwardenSubscriptionQuery>>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly GetBitwardenSubscriptionQuery _query;
public GetBitwardenSubscriptionQueryTests()
{
_query = new GetBitwardenSubscriptionQuery(
_logger,
_pricingClient,
_stripeAdapter);
}
[Fact]
public async Task Run_IncompleteStatus_ReturnsBitwardenSubscriptionWithSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Incomplete);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Incomplete, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(subscription.Created.AddHours(23), result.Suspension);
Assert.Equal(1, result.GracePeriod);
Assert.Null(result.NextCharge);
Assert.Null(result.CancelAt);
}
[Fact]
public async Task Run_IncompleteExpiredStatus_ReturnsBitwardenSubscriptionWithSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.IncompleteExpired);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.IncompleteExpired, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(subscription.Created.AddHours(23), result.Suspension);
Assert.Equal(1, result.GracePeriod);
}
[Fact]
public async Task Run_TrialingStatus_ReturnsBitwardenSubscriptionWithNextCharge()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Trialing);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Trialing, result.Status);
Assert.NotNull(result.NextCharge);
Assert.Equal(subscription.Items.First().CurrentPeriodEnd, result.NextCharge);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_ActiveStatus_ReturnsBitwardenSubscriptionWithNextCharge()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Active, result.Status);
Assert.NotNull(result.NextCharge);
Assert.Equal(subscription.Items.First().CurrentPeriodEnd, result.NextCharge);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_ActiveStatusWithCancelAt_ReturnsCancelAt()
{
var user = CreateUser();
var cancelAt = DateTime.UtcNow.AddMonths(1);
var subscription = CreateSubscription(SubscriptionStatus.Active, cancelAt: cancelAt);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Active, result.Status);
Assert.Equal(cancelAt, result.CancelAt);
}
[Fact]
public async Task Run_PastDueStatus_WithOpenInvoices_ReturnsSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.PastDue, collectionMethod: "charge_automatically");
var premiumPlans = CreatePremiumPlans();
var openInvoice = CreateInvoice();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([openInvoice]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.PastDue, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(openInvoice.Created.AddDays(14), result.Suspension);
Assert.Equal(14, result.GracePeriod);
}
[Fact]
public async Task Run_PastDueStatus_WithoutOpenInvoices_ReturnsNoSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.PastDue);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.PastDue, result.Status);
Assert.Null(result.Suspension);
Assert.Null(result.GracePeriod);
}
[Fact]
public async Task Run_UnpaidStatus_WithOpenInvoices_ReturnsSuspension()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Unpaid, collectionMethod: "charge_automatically");
var premiumPlans = CreatePremiumPlans();
var openInvoice = CreateInvoice();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
_stripeAdapter.SearchInvoiceAsync(Arg.Any<InvoiceSearchOptions>())
.Returns([openInvoice]);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Unpaid, result.Status);
Assert.NotNull(result.Suspension);
Assert.Equal(14, result.GracePeriod);
}
[Fact]
public async Task Run_CanceledStatus_ReturnsCanceledDate()
{
var user = CreateUser();
var canceledAt = DateTime.UtcNow.AddDays(-5);
var subscription = CreateSubscription(SubscriptionStatus.Canceled, canceledAt: canceledAt);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(SubscriptionStatus.Canceled, result.Status);
Assert.Equal(canceledAt, result.Canceled);
Assert.Null(result.Suspension);
Assert.Null(result.NextCharge);
}
[Fact]
public async Task Run_UnmanagedStatus_ThrowsConflictException()
{
var user = CreateUser();
var subscription = CreateSubscription("unmanaged_status");
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
await Assert.ThrowsAsync<ConflictException>(() => _query.Run(user));
}
[Fact]
public async Task Run_WithAdditionalStorage_IncludesStorageInCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: true);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.PasswordManager.AdditionalStorage);
Assert.Equal("additionalStorageGB", result.Cart.PasswordManager.AdditionalStorage.TranslationKey);
Assert.Equal(2, result.Cart.PasswordManager.AdditionalStorage.Quantity);
Assert.NotNull(result.Storage);
}
[Fact]
public async Task Run_WithoutAdditionalStorage_ExcludesStorageFromCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: false);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Null(result.Cart.PasswordManager.AdditionalStorage);
Assert.NotNull(result.Storage);
}
[Fact]
public async Task Run_WithCartLevelDiscount_IncludesDiscountInCart()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
subscription.Customer.Discount = CreateDiscount(discountType: "cart");
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.Discount);
Assert.Equal(BitwardenDiscountType.PercentOff, result.Cart.Discount.Type);
Assert.Equal(20, result.Cart.Discount.Value);
}
[Fact]
public async Task Run_WithProductLevelDiscount_IncludesDiscountInCartItem()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var productDiscount = CreateDiscount(discountType: "product", productId: "prod_premium_seat");
subscription.Discounts = [productDiscount];
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Cart.PasswordManager.Seats.Discount);
Assert.Equal(BitwardenDiscountType.PercentOff, result.Cart.PasswordManager.Seats.Discount.Type);
}
[Fact]
public async Task Run_WithoutMaxStorageGb_ReturnsNullStorage()
{
var user = CreateUser();
user.MaxStorageGb = null;
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Null(result.Storage);
}
[Fact]
public async Task Run_CalculatesStorageCorrectly()
{
var user = CreateUser();
user.Storage = 5368709120; // 5 GB in bytes
user.MaxStorageGb = 10;
var subscription = CreateSubscription(SubscriptionStatus.Active, includeStorage: true);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.NotNull(result.Storage);
Assert.Equal(10, result.Storage.Available);
Assert.Equal(5.0, result.Storage.Used);
Assert.NotEmpty(result.Storage.ReadableUsed);
}
[Fact]
public async Task Run_TaxEstimation_WithInvoiceUpcomingNoneError_ReturnsZeroTax()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.ThrowsAsync(new StripeException { StripeError = new StripeError { Code = ErrorCodes.InvoiceUpcomingNone } });
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(0, result.Cart.EstimatedTax);
}
[Fact]
public async Task Run_MissingPasswordManagerSeatsItem_ThrowsConflictException()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
subscription.Items = new StripeList<SubscriptionItem>
{
Data = []
};
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
await Assert.ThrowsAsync<ConflictException>(() => _query.Run(user));
}
[Fact]
public async Task Run_IncludesEstimatedTax()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
var invoice = CreateInvoicePreview(totalTax: 500); // $5.00 tax
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(invoice);
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(5.0m, result.Cart.EstimatedTax);
}
[Fact]
public async Task Run_SetsCadenceToAnnually()
{
var user = CreateUser();
var subscription = CreateSubscription(SubscriptionStatus.Active);
var premiumPlans = CreatePremiumPlans();
_stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any<SubscriptionGetOptions>())
.Returns(subscription);
_pricingClient.ListPremiumPlans().Returns(premiumPlans);
_stripeAdapter.CreateInvoicePreviewAsync(Arg.Any<InvoiceCreatePreviewOptions>())
.Returns(CreateInvoicePreview());
var result = await _query.Run(user);
Assert.NotNull(result);
Assert.Equal(PlanCadenceType.Annually, result.Cart.Cadence);
}
#region Helper Methods
private static User CreateUser()
{
return new User
{
Id = Guid.NewGuid(),
GatewaySubscriptionId = "sub_test123",
MaxStorageGb = 1,
Storage = 1073741824 // 1 GB in bytes
};
}
private static Subscription CreateSubscription(
string status,
bool includeStorage = false,
DateTime? cancelAt = null,
DateTime? canceledAt = null,
string collectionMethod = "charge_automatically")
{
var currentPeriodEnd = DateTime.UtcNow.AddMonths(1);
var items = new List<SubscriptionItem>
{
new()
{
Id = "si_premium_seat",
Price = new Price
{
Id = "price_premium_seat",
UnitAmountDecimal = 1000,
Product = new Product { Id = "prod_premium_seat" }
},
Quantity = 1,
CurrentPeriodStart = DateTime.UtcNow,
CurrentPeriodEnd = currentPeriodEnd
}
};
if (includeStorage)
{
items.Add(new SubscriptionItem
{
Id = "si_storage",
Price = new Price
{
Id = "price_storage",
UnitAmountDecimal = 400,
Product = new Product { Id = "prod_storage" }
},
Quantity = 2,
CurrentPeriodStart = DateTime.UtcNow,
CurrentPeriodEnd = currentPeriodEnd
});
}
return new Subscription
{
Id = "sub_test123",
Status = status,
Created = DateTime.UtcNow.AddMonths(-1),
Customer = new Customer
{
Id = "cus_test123",
Discount = null
},
Items = new StripeList<SubscriptionItem>
{
Data = items
},
CancelAt = cancelAt,
CanceledAt = canceledAt,
CollectionMethod = collectionMethod,
Discounts = []
};
}
private static List<Bit.Core.Billing.Pricing.Premium.Plan> CreatePremiumPlans()
{
return
[
new()
{
Name = "Premium",
Available = true,
Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_premium_seat",
Price = 10.0m,
Provided = 1
},
Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable
{
StripePriceId = "price_storage",
Price = 4.0m,
Provided = 1
}
}
];
}
private static Invoice CreateInvoice()
{
return new Invoice
{
Id = "in_test123",
Created = DateTime.UtcNow.AddDays(-10),
PeriodEnd = DateTime.UtcNow.AddDays(-5),
Attempted = true,
Status = "open"
};
}
private static Invoice CreateInvoicePreview(long totalTax = 0)
{
var taxes = totalTax > 0
? new List<InvoiceTotalTax> { new() { Amount = totalTax } }
: new List<InvoiceTotalTax>();
return new Invoice
{
Id = "in_preview",
TotalTaxes = taxes
};
}
private static Discount CreateDiscount(string discountType = "cart", string? productId = null)
{
var coupon = new Coupon
{
Valid = true,
PercentOff = 20,
AppliesTo = discountType == "product" && productId != null
? new CouponAppliesTo { Products = [productId] }
: new CouponAppliesTo { Products = [] }
};
return new Discount
{
Coupon = coupon
};
}
#endregion
}

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<RootNamespace>Bit.Core.Test</RootNamespace>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
@@ -30,7 +32,7 @@
<ItemGroup>
<!-- Email templates uses .hbs extension, they must be included for emails to work -->
<EmbeddedResource Include="**\*.hbs" />
<EmbeddedResource Include="Utilities\data\embeddedResource.txt" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,211 @@
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class PlayIdServiceTests
{
[Theory]
[BitAutoData]
public void InPlay_WhenPlayIdSetAndDevelopment_ReturnsTrue(
string playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.True(result);
Assert.Equal(playId, resultPlayId);
}
[Theory]
[BitAutoData]
public void InPlay_WhenPlayIdSetButNotDevelopment_ReturnsFalse(
string playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Production);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.False(result);
Assert.Equal(playId, resultPlayId);
}
[Theory]
[BitAutoData((string?)null)]
[BitAutoData("")]
public void InPlay_WhenPlayIdNullOrEmptyAndDevelopment_ReturnsFalse(
string? playId,
SutProvider<PlayIdService> sutProvider)
{
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
sutProvider.Sut.PlayId = playId;
var result = sutProvider.Sut.InPlay(out var resultPlayId);
Assert.False(result);
Assert.Empty(resultPlayId);
}
[Theory]
[BitAutoData]
public void PlayId_CanGetAndSet(string playId)
{
var hostEnvironment = Substitute.For<IHostEnvironment>();
var sut = new PlayIdService(hostEnvironment);
sut.PlayId = playId;
Assert.Equal(playId, sut.PlayId);
}
}
[SutProviderCustomize]
public class NeverPlayIdServicesTests
{
[Fact]
public void InPlay_ReturnsFalse()
{
var sut = new NeverPlayIdServices();
var result = sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[InlineData("test-play-id")]
[InlineData(null)]
public void PlayId_SetterDoesNothing_GetterReturnsNull(string? value)
{
var sut = new NeverPlayIdServices();
sut.PlayId = value;
Assert.Null(sut.PlayId);
}
}
[SutProviderCustomize]
public class PlayIdSingletonServiceTests
{
public static IEnumerable<object[]> SutProvider()
{
var sutProvider = new SutProvider<PlayIdSingletonService>();
var httpContext = sutProvider.CreateDependency<HttpContext>();
var serviceProvider = sutProvider.CreateDependency<IServiceProvider>();
var hostEnvironment = sutProvider.CreateDependency<IHostEnvironment>();
var playIdService = new PlayIdService(hostEnvironment);
sutProvider.SetDependency(playIdService);
httpContext.RequestServices.Returns(serviceProvider);
serviceProvider.GetService<PlayIdService>().Returns(playIdService);
serviceProvider.GetRequiredService<PlayIdService>().Returns(playIdService);
sutProvider.CreateDependency<IHttpContextAccessor>().HttpContext.Returns(httpContext);
sutProvider.Create();
return [[sutProvider]];
}
private void PrepHttpContext(
SutProvider<PlayIdSingletonService> sutProvider)
{
var httpContext = sutProvider.CreateDependency<HttpContext>();
var serviceProvider = sutProvider.CreateDependency<IServiceProvider>();
var PlayIdService = sutProvider.CreateDependency<PlayIdService>();
httpContext.RequestServices.Returns(serviceProvider);
serviceProvider.GetRequiredService<PlayIdService>().Returns(PlayIdService);
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns(httpContext);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenNoHttpContext_ReturnsFalse(
SutProvider<PlayIdSingletonService> sutProvider)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenNotDevelopment_ReturnsFalse(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
var scopedPlayIdService = sutProvider.GetDependency<PlayIdService>();
scopedPlayIdService.PlayId = playIdValue;
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Production);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.False(result);
Assert.Empty(playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void InPlay_WhenDevelopmentAndHttpContextWithPlayId_ReturnsTrue(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
sutProvider.GetDependency<PlayIdService>().PlayId = playIdValue;
sutProvider.GetDependency<IHostEnvironment>().EnvironmentName.Returns(Environments.Development);
var result = sutProvider.Sut.InPlay(out var playId);
Assert.True(result);
Assert.Equal(playIdValue, playId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_SetterSetsOnScopedService(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
var scopedPlayIdService = sutProvider.GetDependency<PlayIdService>();
sutProvider.Sut.PlayId = playIdValue;
Assert.Equal(playIdValue, scopedPlayIdService.PlayId);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_WhenNoHttpContext_GetterReturnsNull(
SutProvider<PlayIdSingletonService> sutProvider)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
var result = sutProvider.Sut.PlayId;
Assert.Null(result);
}
[Theory]
[BitMemberAutoData(nameof(SutProvider))]
public void PlayId_WhenNoHttpContext_SetterDoesNotThrow(
SutProvider<PlayIdSingletonService> sutProvider,
string playIdValue)
{
sutProvider.GetDependency<IHttpContextAccessor>().HttpContext.Returns((HttpContext?)null);
sutProvider.Sut.PlayId = playIdValue;
}
}

View File

@@ -0,0 +1,143 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Services;
[SutProviderCustomize]
public class PlayItemServiceTests
{
[Theory]
[BitAutoData]
public async Task Record_User_WhenInPlay_RecordsPlayItem(
string playId,
User user,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = playId;
return true;
});
await sutProvider.Sut.Record(user);
await sutProvider.GetDependency<IPlayItemRepository>()
.Received(1)
.CreateAsync(Arg.Is<PlayItem>(pd =>
pd.PlayId == playId &&
pd.UserId == user.Id &&
pd.OrganizationId == null));
sutProvider.GetDependency<ILogger<PlayItemService>>()
.Received(1)
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o => o.ToString().Contains(user.Id.ToString()) && o.ToString().Contains(playId)),
null,
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_User_WhenNotInPlay_DoesNotRecordPlayItem(
User user,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = null;
return false;
});
await sutProvider.Sut.Record(user);
await sutProvider.GetDependency<IPlayItemRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<PlayItem>());
sutProvider.GetDependency<ILogger<PlayItemService>>()
.DidNotReceive()
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_Organization_WhenInPlay_RecordsPlayItem(
string playId,
Organization organization,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = playId;
return true;
});
await sutProvider.Sut.Record(organization);
await sutProvider.GetDependency<IPlayItemRepository>()
.Received(1)
.CreateAsync(Arg.Is<PlayItem>(pd =>
pd.PlayId == playId &&
pd.OrganizationId == organization.Id &&
pd.UserId == null));
sutProvider.GetDependency<ILogger<PlayItemService>>()
.Received(1)
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Is<object>(o => o.ToString().Contains(organization.Id.ToString()) && o.ToString().Contains(playId)),
null,
Arg.Any<Func<object, Exception?, string>>());
}
[Theory]
[BitAutoData]
public async Task Record_Organization_WhenNotInPlay_DoesNotRecordPlayItem(
Organization organization,
SutProvider<PlayItemService> sutProvider)
{
sutProvider.GetDependency<IPlayIdService>()
.InPlay(out Arg.Any<string>())
.Returns(x =>
{
x[0] = null;
return false;
});
await sutProvider.Sut.Record(organization);
await sutProvider.GetDependency<IPlayItemRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<PlayItem>());
sutProvider.GetDependency<ILogger<PlayItemService>>()
.DidNotReceive()
.Log(
LogLevel.Information,
Arg.Any<EventId>(),
Arg.Any<object>(),
Arg.Any<Exception>(),
Arg.Any<Func<object, Exception?, string>>());
}
}

View File

@@ -0,0 +1,219 @@
using System.Runtime.Serialization;
using System.Text.Json;
using System.Text.Json.Serialization;
using Bit.Core.Utilities;
using Xunit;
namespace Bit.Core.Test.Utilities;
public class EnumMemberJsonConverterTests
{
[Fact]
public void Serialize_WithEnumMemberAttribute_UsesAttributeValue()
{
// Arrange
var obj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.InProgress
};
const string expectedJsonString = "{\"Status\":\"in_progress\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Serialize_WithoutEnumMemberAttribute_UsesEnumName()
{
// Arrange
var obj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Pending
};
const string expectedJsonString = "{\"Status\":\"Pending\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Serialize_MultipleValues_SerializesCorrectly()
{
// Arrange
var obj = new EnumConverterTestObjectWithMultiple
{
Status1 = EnumConverterTestStatus.Active,
Status2 = EnumConverterTestStatus.InProgress,
Status3 = EnumConverterTestStatus.Pending
};
const string expectedJsonString = "{\"Status1\":\"active\",\"Status2\":\"in_progress\",\"Status3\":\"Pending\"}";
// Act
var jsonString = JsonSerializer.Serialize(obj);
// Assert
Assert.Equal(expectedJsonString, jsonString);
}
[Fact]
public void Deserialize_WithEnumMemberAttribute_ReturnsCorrectEnumValue()
{
// Arrange
const string json = "{\"Status\":\"in_progress\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.InProgress, obj.Status);
}
[Fact]
public void Deserialize_WithoutEnumMemberAttribute_ReturnsCorrectEnumValue()
{
// Arrange
const string json = "{\"Status\":\"Pending\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.Pending, obj.Status);
}
[Fact]
public void Deserialize_MultipleValues_DeserializesCorrectly()
{
// Arrange
const string json = "{\"Status1\":\"active\",\"Status2\":\"in_progress\",\"Status3\":\"Pending\"}";
// Act
var obj = JsonSerializer.Deserialize<EnumConverterTestObjectWithMultiple>(json);
// Assert
Assert.Equal(EnumConverterTestStatus.Active, obj.Status1);
Assert.Equal(EnumConverterTestStatus.InProgress, obj.Status2);
Assert.Equal(EnumConverterTestStatus.Pending, obj.Status3);
}
[Fact]
public void Deserialize_InvalidEnumString_ThrowsJsonException()
{
// Arrange
const string json = "{\"Status\":\"invalid_value\"}";
// Act & Assert
var exception = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<EnumConverterTestObject>(json));
Assert.Contains("Unable to convert 'invalid_value' to EnumConverterTestStatus", exception.Message);
}
[Fact]
public void Deserialize_EmptyString_ThrowsJsonException()
{
// Arrange
const string json = "{\"Status\":\"\"}";
// Act & Assert
var exception = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<EnumConverterTestObject>(json));
Assert.Contains("Unable to convert '' to EnumConverterTestStatus", exception.Message);
}
[Fact]
public void RoundTrip_WithEnumMemberAttribute_PreservesValue()
{
// Arrange
var originalObj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Completed
};
// Act
var json = JsonSerializer.Serialize(originalObj);
var deserializedObj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(originalObj.Status, deserializedObj.Status);
}
[Fact]
public void RoundTrip_WithoutEnumMemberAttribute_PreservesValue()
{
// Arrange
var originalObj = new EnumConverterTestObject
{
Status = EnumConverterTestStatus.Pending
};
// Act
var json = JsonSerializer.Serialize(originalObj);
var deserializedObj = JsonSerializer.Deserialize<EnumConverterTestObject>(json);
// Assert
Assert.Equal(originalObj.Status, deserializedObj.Status);
}
[Fact]
public void Serialize_AllEnumValues_ProducesExpectedStrings()
{
// Arrange & Act & Assert
Assert.Equal("\"Pending\"", JsonSerializer.Serialize(EnumConverterTestStatus.Pending, CreateOptions()));
Assert.Equal("\"active\"", JsonSerializer.Serialize(EnumConverterTestStatus.Active, CreateOptions()));
Assert.Equal("\"in_progress\"", JsonSerializer.Serialize(EnumConverterTestStatus.InProgress, CreateOptions()));
Assert.Equal("\"completed\"", JsonSerializer.Serialize(EnumConverterTestStatus.Completed, CreateOptions()));
}
[Fact]
public void Deserialize_AllEnumValues_ReturnsCorrectEnums()
{
// Arrange & Act & Assert
Assert.Equal(EnumConverterTestStatus.Pending, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"Pending\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.Active, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"active\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.InProgress, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"in_progress\"", CreateOptions()));
Assert.Equal(EnumConverterTestStatus.Completed, JsonSerializer.Deserialize<EnumConverterTestStatus>("\"completed\"", CreateOptions()));
}
private static JsonSerializerOptions CreateOptions()
{
var options = new JsonSerializerOptions();
options.Converters.Add(new EnumMemberJsonConverter<EnumConverterTestStatus>());
return options;
}
}
public class EnumConverterTestObject
{
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status { get; set; }
}
public class EnumConverterTestObjectWithMultiple
{
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status1 { get; set; }
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status2 { get; set; }
[JsonConverter(typeof(EnumMemberJsonConverter<EnumConverterTestStatus>))]
public EnumConverterTestStatus Status3 { get; set; }
}
public enum EnumConverterTestStatus
{
Pending, // No EnumMemberAttribute
[EnumMember(Value = "active")]
Active,
[EnumMember(Value = "in_progress")]
InProgress,
[EnumMember(Value = "completed")]
Completed
}

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1304;CA1305</WarningsNotAsErrors>
</PropertyGroup>
<PropertyGroup Condition=" '$(RunConfiguration)' == 'Identity.IntegrationTest' " />

View File

@@ -2,6 +2,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -1,5 +1,4 @@
using Bit.Core;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Context;
@@ -7,7 +6,6 @@ using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Identity.IdentityServer;
using Bit.Identity.Test.AutoFixture;
using Bit.Identity.Utilities;
@@ -25,7 +23,6 @@ public class UserDecryptionOptionsBuilderTests
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
private readonly UserDecryptionOptionsBuilder _builder;
private readonly IFeatureService _featureService;
public UserDecryptionOptionsBuilderTests()
{
@@ -33,8 +30,7 @@ public class UserDecryptionOptionsBuilderTests
_deviceRepository = Substitute.For<IDeviceRepository>();
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_loginApprovingClientTypes = Substitute.For<ILoginApprovingClientTypes>();
_featureService = Substitute.For<IFeatureService>();
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes, _featureService);
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes);
var user = new User();
_builder.ForUser(user);
}
@@ -227,43 +223,6 @@ public class UserDecryptionOptionsBuilderTests
Assert.False(result.TrustedDeviceOption?.HasLoginApprovingDevice);
}
/// <summary>
/// This logic has been flagged as part of PM-23174.
/// When removing the server flag, please also remove this test, and remove the FeatureService
/// dependency from this suite and the following test.
/// </summary>
/// <param name="organizationUserType"></param>
/// <param name="ssoConfig"></param>
/// <param name="configurationData"></param>
/// <param name="organization"></param>
/// <param name="organizationUser"></param>
/// <param name="user"></param>
[Theory]
[BitAutoData(OrganizationUserType.Custom)]
public async Task Build_WhenManageResetPasswordPermissions_ShouldReturnHasManageResetPasswordPermissionTrue(
OrganizationUserType organizationUserType,
SsoConfig ssoConfig,
SsoConfigurationData configurationData,
CurrentContextOrganization organization,
[OrganizationUserWithDefaultPermissions] OrganizationUser organizationUser,
User user)
{
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize();
ssoConfig.OrganizationId = organization.Id;
_currentContext.Organizations.Returns([organization]);
_currentContext.ManageResetPassword(organization.Id).Returns(true);
organizationUser.Type = organizationUserType;
organizationUser.OrganizationId = organization.Id;
organizationUser.UserId = user.Id;
organizationUser.SetPermissions(new Permissions() { ManageResetPassword = true });
_organizationUserRepository.GetByOrganizationAsync(ssoConfig.OrganizationId, user.Id).Returns(organizationUser);
var result = await _builder.ForUser(user).WithSso(ssoConfig).BuildAsync();
Assert.True(result.TrustedDeviceOption?.HasManageResetPasswordPermission);
}
[Theory]
[BitAutoData(OrganizationUserType.Custom)]
public async Task Build_WhenManageResetPasswordPermissions_ShouldFetchUserFromRepositoryAndReturnHasManageResetPasswordPermissionTrue(
@@ -274,8 +233,6 @@ public class UserDecryptionOptionsBuilderTests
[OrganizationUserWithDefaultPermissions] OrganizationUser organizationUser,
User user)
{
_featureService.IsEnabled(FeatureFlagKeys.PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword)
.Returns(true);
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize();
ssoConfig.OrganizationId = organization.Id;

View File

@@ -1,6 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">

View File

@@ -128,7 +128,6 @@ public class DatabaseDataAttribute : DataAttribute
private void AddDapperServices(IServiceCollection services, Database database)
{
services.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
@@ -141,6 +140,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
}
};
services.AddDapperRepositories(SelfHosted);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);
services.AddSingleton(database);
@@ -160,7 +160,6 @@ public class DatabaseDataAttribute : DataAttribute
private void AddEfServices(IServiceCollection services, Database database)
{
services.SetupEntityFramework(database.ConnectionString, database.Type);
services.AddPasswordManagerEFRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
@@ -169,6 +168,7 @@ public class DatabaseDataAttribute : DataAttribute
UserRequestExpiration = TimeSpan.FromMinutes(15),
},
};
services.AddPasswordManagerEFRepositories(SelfHosted);
services.AddSingleton(globalSettings);
services.AddSingleton<IGlobalSettings>(globalSettings);

View File

@@ -3,6 +3,8 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<UserSecretsId>6570f288-5c2c-47ad-8978-f3da255079c2</UserSecretsId>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>

View File

@@ -189,7 +189,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
/// Registers a new user to the Identity Application Factory based on the RegisterFinishRequestModel
/// </summary>
/// <param name="requestModel">RegisterFinishRequestModel needed to seed data to the test user</param>
/// <param name="marketingEmails">optional parameter that is tracked during the inital steps of registration.</param>
/// <param name="marketingEmails">optional parameter that is tracked during the initial steps of registration.</param>
/// <returns>returns the newly created user</returns>
public async Task<User> RegisterNewIdentityFactoryUserAsync(
RegisterFinishRequestModel requestModel,

View File

@@ -47,7 +47,7 @@ public abstract class WebApplicationFactoryBase<T> : WebApplicationFactory<T>
/// </remarks>
public bool ManagesDatabase { get; set; } = true;
private readonly List<Action<IServiceCollection>> _configureTestServices = new();
protected readonly List<Action<IServiceCollection>> _configureTestServices = new();
private readonly List<Action<IConfigurationBuilder>> _configureAppConfiguration = new();
public void SubstituteService<TService>(Action<TService> mockService)

View File

@@ -2,13 +2,15 @@
<PropertyGroup>
<IsPackable>false</IsPackable>
<!-- These opt outs should be removed when all warnings are addressed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CA1305</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Mvc.Testing" Version="8.0.10" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Identity\Identity.csproj" />
<ProjectReference Include="..\..\util\Migrator\Migrator.csproj" />

View File

@@ -0,0 +1,40 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
namespace Bit.SeederApi.IntegrationTest;
public static class HttpClientExtensions
{
/// <summary>
/// Sends a POST request with JSON content and attaches the x-play-id header.
/// </summary>
/// <typeparam name="TValue">The type of the value to serialize.</typeparam>
/// <param name="client">The HTTP client.</param>
/// <param name="requestUri">The URI the request is sent to.</param>
/// <param name="value">The value to serialize.</param>
/// <param name="playId">The play ID to attach as x-play-id header.</param>
/// <param name="options">Options to control the behavior during serialization.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
/// <returns>The task object representing the asynchronous operation.</returns>
public static Task<HttpResponseMessage> PostAsJsonAsync<TValue>(
this HttpClient client,
[StringSyntax(StringSyntaxAttribute.Uri)] string? requestUri,
TValue value,
string playId,
JsonSerializerOptions? options = null,
CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(client);
if (string.IsNullOrWhiteSpace(playId))
{
throw new ArgumentException("Play ID cannot be null or whitespace.", nameof(playId));
}
var content = JsonContent.Create(value, mediaType: null, options);
content.Headers.Remove("x-play-id");
content.Headers.Add("x-play-id", playId);
return client.PostAsync(requestUri, content, cancellationToken);
}
}

View File

@@ -0,0 +1,75 @@
using System.Net;
using Bit.SeederApi.Models.Request;
using Xunit;
namespace Bit.SeederApi.IntegrationTest;
public class QueryControllerTests : IClassFixture<SeederApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly SeederApiApplicationFactory _factory;
public QueryControllerTests(SeederApiApplicationFactory factory)
{
_factory = factory;
_client = _factory.CreateClient();
}
public Task InitializeAsync()
{
return Task.CompletedTask;
}
public Task DisposeAsync()
{
_client.Dispose();
return Task.CompletedTask;
}
[Fact]
public async Task QueryEndpoint_WithValidQueryAndArguments_ReturnsOk()
{
var testEmail = $"emergency-test-{Guid.NewGuid()}@bitwarden.com";
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "EmergencyAccessInviteQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
});
response.EnsureSuccessStatusCode();
var result = await response.Content.ReadAsStringAsync();
Assert.NotNull(result);
var urls = System.Text.Json.JsonSerializer.Deserialize<List<string>>(result);
Assert.NotNull(urls);
// For a non-existent email, we expect an empty list
Assert.Empty(urls);
}
[Fact]
public async Task QueryEndpoint_WithInvalidQueryName_ReturnsNotFound()
{
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "NonExistentQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = "test@example.com" })
});
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}
[Fact]
public async Task QueryEndpoint_WithMissingRequiredField_ReturnsBadRequest()
{
// EmergencyAccessInviteQuery requires 'email' field
var response = await _client.PostAsJsonAsync("/query", new QueryRequestModel
{
Template = "EmergencyAccessInviteQuery",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { wrongField = "value" })
});
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
}

View File

@@ -0,0 +1,222 @@
using System.Net;
using Bit.SeederApi.Models.Request;
using Bit.SeederApi.Models.Response;
using Xunit;
namespace Bit.SeederApi.IntegrationTest;
public class SeedControllerTests : IClassFixture<SeederApiApplicationFactory>, IAsyncLifetime
{
private readonly HttpClient _client;
private readonly SeederApiApplicationFactory _factory;
public SeedControllerTests(SeederApiApplicationFactory factory)
{
_factory = factory;
_client = _factory.CreateClient();
}
public Task InitializeAsync()
{
return Task.CompletedTask;
}
public async Task DisposeAsync()
{
// Clean up any seeded data after each test
await _client.DeleteAsync("/seed");
_client.Dispose();
}
[Fact]
public async Task SeedEndpoint_WithValidScene_ReturnsOk()
{
var testEmail = $"seed-test-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
response.EnsureSuccessStatusCode();
var result = await response.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(result);
Assert.NotNull(result.MangleMap);
Assert.Null(result.Result);
}
[Fact]
public async Task SeedEndpoint_WithInvalidSceneName_ReturnsNotFound()
{
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "NonExistentScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = "test@example.com" })
});
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}
[Fact]
public async Task SeedEndpoint_WithMissingRequiredField_ReturnsBadRequest()
{
// SingleUserScene requires 'email' field
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { wrongField = "value" })
});
Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode);
}
[Fact]
public async Task DeleteEndpoint_WithValidPlayId_ReturnsOk()
{
var testEmail = $"delete-test-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
var deleteResponse = await _client.DeleteAsync($"/seed/{playId}");
deleteResponse.EnsureSuccessStatusCode();
}
[Fact]
public async Task DeleteEndpoint_WithInvalidPlayId_ReturnsOk()
{
// DestroyRecipe is idempotent - returns null for non-existent play IDs
var nonExistentPlayId = Guid.NewGuid().ToString();
var response = await _client.DeleteAsync($"/seed/{nonExistentPlayId}");
response.EnsureSuccessStatusCode();
var content = await response.Content.ReadAsStringAsync();
Assert.Equal($$"""{"playId":"{{nonExistentPlayId}}"}""", content);
}
[Fact]
public async Task DeleteBatchEndpoint_WithValidPlayIds_ReturnsOk()
{
// Create multiple seeds with different play IDs
var playIds = new List<string>();
for (var i = 0; i < 3; i++)
{
var playId = Guid.NewGuid().ToString();
playIds.Add(playId);
var testEmail = $"batch-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
}
// Delete them in batch
var request = new HttpRequestMessage(HttpMethod.Delete, "/seed/batch")
{
Content = JsonContent.Create(playIds)
};
var deleteResponse = await _client.SendAsync(request);
deleteResponse.EnsureSuccessStatusCode();
var result = await deleteResponse.Content.ReadFromJsonAsync<BatchDeleteResponse>();
Assert.NotNull(result);
Assert.Equal("Batch delete completed successfully", result.Message);
}
[Fact]
public async Task DeleteBatchEndpoint_WithSomeInvalidIds_ReturnsOk()
{
// DestroyRecipe is idempotent - batch delete succeeds even with non-existent IDs
// Create one valid seed with a play ID
var validPlayId = Guid.NewGuid().ToString();
var testEmail = $"batch-partial-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, validPlayId);
seedResponse.EnsureSuccessStatusCode();
var seedResult = await seedResponse.Content.ReadFromJsonAsync<SceneResponseModel>();
Assert.NotNull(seedResult);
// Try to delete with mix of valid and invalid IDs
var playIds = new List<string> { validPlayId, Guid.NewGuid().ToString(), Guid.NewGuid().ToString() };
var request = new HttpRequestMessage(HttpMethod.Delete, "/seed/batch")
{
Content = JsonContent.Create(playIds)
};
var deleteResponse = await _client.SendAsync(request);
deleteResponse.EnsureSuccessStatusCode();
var result = await deleteResponse.Content.ReadFromJsonAsync<BatchDeleteResponse>();
Assert.NotNull(result);
Assert.Equal("Batch delete completed successfully", result.Message);
}
[Fact]
public async Task DeleteAllEndpoint_DeletesAllSeededData()
{
// Create multiple seeds
for (var i = 0; i < 2; i++)
{
var playId = Guid.NewGuid().ToString();
var testEmail = $"deleteall-test-{Guid.NewGuid()}@bitwarden.com";
var seedResponse = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
seedResponse.EnsureSuccessStatusCode();
}
// Delete all
var deleteResponse = await _client.DeleteAsync("/seed");
Assert.Equal(HttpStatusCode.NoContent, deleteResponse.StatusCode);
}
[Fact]
public async Task SeedEndpoint_VerifyResponseContainsMangleMapAndResult()
{
var testEmail = $"verify-response-{Guid.NewGuid()}@bitwarden.com";
var playId = Guid.NewGuid().ToString();
var response = await _client.PostAsJsonAsync("/seed", new SeedRequestModel
{
Template = "SingleUserScene",
Arguments = System.Text.Json.JsonSerializer.SerializeToElement(new { email = testEmail })
}, playId);
response.EnsureSuccessStatusCode();
var jsonString = await response.Content.ReadAsStringAsync();
// Verify the response contains MangleMap and Result fields
Assert.Contains("mangleMap", jsonString, StringComparison.OrdinalIgnoreCase);
Assert.Contains("result", jsonString, StringComparison.OrdinalIgnoreCase);
}
private class BatchDeleteResponse
{
public string? Message { get; set; }
}
}

View File

@@ -0,0 +1,29 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="$(CoverletCollectorVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\util\SeederApi\SeederApi.csproj" />
<ProjectReference Include="..\..\util\Seeder\Seeder.csproj" />
<ProjectReference Include="..\IntegrationTestCommon\IntegrationTestCommon.csproj" />
<Content Include="..\..\util\SeederApi\appsettings.*.json">
<Link>%(RecursiveDir)%(Filename)%(Extension)</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,18 @@
using Bit.Core.Services;
using Bit.IntegrationTestCommon;
using Bit.IntegrationTestCommon.Factories;
namespace Bit.SeederApi.IntegrationTest;
public class SeederApiApplicationFactory : WebApplicationFactoryBase<Startup>
{
public SeederApiApplicationFactory()
{
TestDatabase = new SqliteTestDatabase();
_configureTestServices.Add(serviceCollection =>
{
serviceCollection.AddSingleton<IPlayIdService, NeverPlayIdServices>();
serviceCollection.AddHttpContextAccessor();
});
}
}

View File

@@ -0,0 +1,102 @@
using Bit.Core.Services;
using Bit.SharedWeb.Utilities;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Hosting;
using NSubstitute;
namespace SharedWeb.Test;
public class PlayIdMiddlewareTests
{
private readonly PlayIdService _playIdService;
private readonly RequestDelegate _next;
private readonly PlayIdMiddleware _middleware;
public PlayIdMiddlewareTests()
{
var hostEnvironment = Substitute.For<IHostEnvironment>();
hostEnvironment.EnvironmentName.Returns(Environments.Development);
_playIdService = new PlayIdService(hostEnvironment);
_next = Substitute.For<RequestDelegate>();
_middleware = new PlayIdMiddleware(_next);
}
[Fact]
public async Task Invoke_WithValidPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play-id";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play-id", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithoutPlayIdHeader_CallsNext()
{
var context = new DefaultHttpContext();
await _middleware.Invoke(context, _playIdService);
Assert.Null(_playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Theory]
[InlineData("")]
[InlineData(" ")]
[InlineData("\t")]
public async Task Invoke_WithEmptyOrWhitespacePlayId_Returns400(string playId)
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
context.Request.Headers["x-play-id"] = playId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdExceedingMaxLength_Returns400()
{
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
var longPlayId = new string('a', 257); // Exceeds 256 character limit
context.Request.Headers["x-play-id"] = longPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await _next.DidNotReceive().Invoke(context);
}
[Fact]
public async Task Invoke_WithPlayIdAtMaxLength_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
var maxLengthPlayId = new string('a', 256); // Exactly 256 characters
context.Request.Headers["x-play-id"] = maxLengthPlayId;
await _middleware.Invoke(context, _playIdService);
Assert.Equal(maxLengthPlayId, _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
[Fact]
public async Task Invoke_WithSpecialCharactersInPlayId_SetsPlayIdAndCallsNext()
{
var context = new DefaultHttpContext();
context.Request.Headers["x-play-id"] = "test-play_id.123";
await _middleware.Invoke(context, _playIdService);
Assert.Equal("test-play_id.123", _playIdService.PlayId);
await _next.Received(1).Invoke(context);
}
}