1
0
mirror of https://github.com/bitwarden/server synced 2025-12-06 00:03:34 +00:00

[PM-25088] - refactor premium purchase endpoint (#6262)

* [PM-25088] add feature flag for new premium subscription flow

* [PM-25088] refactor premium endpoint

* forgot the punctuation change in the test

* [PM-25088] - pr feedback

* [PM-25088] - pr feedback round two
This commit is contained in:
Kyle Denney
2025-09-10 10:08:22 -05:00
committed by GitHub
parent d43b00dad9
commit a458db319e
25 changed files with 1309 additions and 21 deletions

View File

@@ -0,0 +1,13 @@
using Bit.Api.Utilities;
namespace Bit.Api.Billing.Attributes;
public class PaymentMethodTypeValidationAttribute : StringMatchesAttribute
{
private static readonly string[] _acceptedValues = ["bankAccount", "card", "payPal"];
public PaymentMethodTypeValidationAttribute() : base(_acceptedValues)
{
ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}";
}
}

View File

@@ -1,8 +1,11 @@
#nullable enable
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Api.Billing.Models.Requests.Premium;
using Bit.Core;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Queries;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Entities;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@@ -16,6 +19,7 @@ namespace Bit.Api.Billing.Controllers.VNext;
[SelfHosted(NotSelfHostedOnly = true)]
public class AccountBillingVNextController(
ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand,
ICreatePremiumCloudHostedSubscriptionCommand createPremiumCloudHostedSubscriptionCommand,
IGetCreditQuery getCreditQuery,
IGetPaymentMethodQuery getPaymentMethodQuery,
IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController
@@ -61,4 +65,17 @@ public class AccountBillingVNextController(
var result = await updatePaymentMethodCommand.Run(user, paymentMethod, billingAddress);
return Handle(result);
}
[HttpPost("subscription")]
[RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)]
[InjectUser]
public async Task<IResult> CreateSubscriptionAsync(
[BindNever] User user,
[FromBody] PremiumCloudHostedSubscriptionRequest request)
{
var (paymentMethod, billingAddress, additionalStorageGb) = request.ToDomain();
var result = await createPremiumCloudHostedSubscriptionCommand.Run(
user, paymentMethod, billingAddress, additionalStorageGb);
return Handle(result);
}
}

View File

@@ -0,0 +1,38 @@
#nullable enable
using Bit.Api.Billing.Attributes;
using Bit.Api.Billing.Models.Requests.Premium;
using Bit.Api.Utilities;
using Bit.Core;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.ModelBinding;
namespace Bit.Api.Billing.Controllers.VNext;
[Authorize("Application")]
[Route("account/billing/vnext/self-host")]
[SelfHosted(SelfHostedOnly = true)]
public class SelfHostedAccountBillingController(
ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController
{
[HttpPost("license")]
[RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)]
[InjectUser]
public async Task<IResult> UploadLicenseAsync(
[BindNever] User user,
PremiumSelfHostedSubscriptionRequest request)
{
var license = await ApiHelpers.ReadJsonFileFromBody<UserLicense>(HttpContext, request.License);
if (license == null)
{
throw new BadRequestException("Invalid license.");
}
var result = await createPremiumSelfHostedSubscriptionCommand.Run(user, license);
return Handle(result);
}
}

View File

@@ -0,0 +1,25 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;
public class MinimalTokenizedPaymentMethodRequest
{
[Required]
[PaymentMethodTypeValidation]
public required string Type { get; set; }
[Required]
public required string Token { get; set; }
public TokenizedPaymentMethod ToDomain()
{
return new TokenizedPaymentMethod
{
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};
}
}

View File

@@ -1,6 +1,6 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Api.Utilities;
using Bit.Api.Billing.Attributes;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Payment;
@@ -8,8 +8,7 @@ namespace Bit.Api.Billing.Models.Requests.Payment;
public class TokenizedPaymentMethodRequest
{
[Required]
[StringMatches("bankAccount", "card", "payPal",
ErrorMessage = "Payment method type must be one of: bankAccount, card, payPal")]
[PaymentMethodTypeValidation]
public required string Type { get; set; }
[Required]
@@ -21,14 +20,7 @@ public class TokenizedPaymentMethodRequest
{
var paymentMethod = new TokenizedPaymentMethod
{
Type = Type switch
{
"bankAccount" => TokenizablePaymentMethodType.BankAccount,
"card" => TokenizablePaymentMethodType.Card,
"payPal" => TokenizablePaymentMethodType.PayPal,
_ => throw new InvalidOperationException(
$"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}")
},
Type = TokenizablePaymentMethodTypeExtensions.From(Type),
Token = Token
};

View File

@@ -0,0 +1,26 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Api.Billing.Models.Requests.Payment;
using Bit.Core.Billing.Payment.Models;
namespace Bit.Api.Billing.Models.Requests.Premium;
public class PremiumCloudHostedSubscriptionRequest
{
[Required]
public required MinimalTokenizedPaymentMethodRequest TokenizedPaymentMethod { get; set; }
[Required]
public required MinimalBillingAddressRequest BillingAddress { get; set; }
[Range(0, 99)]
public short AdditionalStorageGb { get; set; } = 0;
public (TokenizedPaymentMethod, BillingAddress, short) ToDomain()
{
var paymentMethod = TokenizedPaymentMethod.ToDomain();
var billingAddress = BillingAddress.ToDomain();
return (paymentMethod, billingAddress, AdditionalStorageGb);
}
}

View File

@@ -0,0 +1,10 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests.Premium;
public class PremiumSelfHostedSubscriptionRequest
{
[Required]
public required IFormFile License { get; set; }
}

View File

@@ -23,7 +23,7 @@ public static class OrganizationFactory
PlanType = claimsPrincipal.GetValue<PlanType>(OrganizationLicenseConstants.PlanType),
Seats = claimsPrincipal.GetValue<int?>(OrganizationLicenseConstants.Seats),
MaxCollections = claimsPrincipal.GetValue<short?>(OrganizationLicenseConstants.MaxCollections),
MaxStorageGb = 10240,
MaxStorageGb = Constants.SelfHostedMaxStorageGb,
UsePolicies = claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UsePolicies),
UseSso = claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseSso),
UseKeyConnector = claimsPrincipal.GetValue<bool>(OrganizationLicenseConstants.UseKeyConnector),
@@ -75,7 +75,7 @@ public static class OrganizationFactory
PlanType = license.PlanType,
Seats = license.Seats,
MaxCollections = license.MaxCollections,
MaxStorageGb = 10240,
MaxStorageGb = Constants.SelfHostedMaxStorageGb,
UsePolicies = license.UsePolicies,
UseSso = license.UseSso,
UseKeyConnector = license.UseKeyConnector,

View File

@@ -79,6 +79,7 @@ public static class StripeConstants
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
public const string PremiumAnnually = "premium-annually";
}
public static class ProrationBehavior

View File

@@ -5,6 +5,7 @@ using Bit.Core.Billing.Organizations.Commands;
using Bit.Core.Billing.Organizations.Queries;
using Bit.Core.Billing.Organizations.Services;
using Bit.Core.Billing.Payment;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
@@ -30,6 +31,7 @@ public static class ServiceCollectionExtensions
services.AddTransient<IPreviewTaxAmountCommand, PreviewTaxAmountCommand>();
services.AddPaymentOperations();
services.AddOrganizationLicenseCommandsQueries();
services.AddPremiumCommands();
services.AddTransient<IGetOrganizationWarningsQuery, GetOrganizationWarningsQuery>();
}
@@ -39,4 +41,10 @@ public static class ServiceCollectionExtensions
services.AddScoped<IGetSelfHostedOrganizationLicenseQuery, GetSelfHostedOrganizationLicenseQuery>();
services.AddScoped<IUpdateOrganizationLicenseCommand, UpdateOrganizationLicenseCommand>();
}
private static void AddPremiumCommands(this IServiceCollection services)
{
services.AddScoped<ICreatePremiumCloudHostedSubscriptionCommand, CreatePremiumCloudHostedSubscriptionCommand>();
services.AddScoped<ICreatePremiumSelfHostedSubscriptionCommand, CreatePremiumSelfHostedSubscriptionCommand>();
}
}

View File

@@ -6,3 +6,17 @@ public enum TokenizablePaymentMethodType
Card,
PayPal
}
public static class TokenizablePaymentMethodTypeExtensions
{
public static TokenizablePaymentMethodType From(string type)
{
return type switch
{
"bankAccount" => TokenizablePaymentMethodType.BankAccount,
"card" => TokenizablePaymentMethodType.Card,
"payPal" => TokenizablePaymentMethodType.PayPal,
_ => throw new InvalidOperationException($"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}")
};
}
}

View File

@@ -0,0 +1,308 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Braintree;
using Microsoft.Extensions.Logging;
using OneOf.Types;
using Stripe;
using Customer = Stripe.Customer;
using Subscription = Stripe.Subscription;
namespace Bit.Core.Billing.Premium.Commands;
using static Utilities;
/// <summary>
/// Creates a premium subscription for a cloud-hosted user with Stripe payment processing.
/// Handles customer creation, payment method setup, and subscription creation.
/// </summary>
public interface ICreatePremiumCloudHostedSubscriptionCommand
{
/// <summary>
/// Creates a premium cloud-hosted subscription for the specified user.
/// </summary>
/// <param name="user">The user to create the premium subscription for. Must not already be a premium user.</param>
/// <param name="paymentMethod">The tokenized payment method containing the payment type and token for billing.</param>
/// <param name="billingAddress">The billing address information required for tax calculation and customer creation.</param>
/// <param name="additionalStorageGb">Additional storage in GB beyond the base 1GB included with premium (must be >= 0).</param>
/// <returns>A billing command result indicating success or failure with appropriate error details.</returns>
Task<BillingCommandResult<None>> Run(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress,
short additionalStorageGb);
}
public class CreatePremiumCloudHostedSubscriptionCommand(
IBraintreeGateway braintreeGateway,
IGlobalSettings globalSettings,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserService userService,
IPushNotificationService pushNotificationService,
ILogger<CreatePremiumCloudHostedSubscriptionCommand> logger)
: BaseBillingCommand<CreatePremiumCloudHostedSubscriptionCommand>(logger), ICreatePremiumCloudHostedSubscriptionCommand
{
private static readonly List<string> _expand = ["tax"];
private readonly ILogger<CreatePremiumCloudHostedSubscriptionCommand> _logger = logger;
public Task<BillingCommandResult<None>> Run(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress,
short additionalStorageGb) => HandleAsync<None>(async () =>
{
if (user.Premium)
{
return new BadRequest("Already a premium user.");
}
if (additionalStorageGb < 0)
{
return new BadRequest("Additional storage must be greater than 0.");
}
var customer = string.IsNullOrEmpty(user.GatewayCustomerId)
? await CreateCustomerAsync(user, paymentMethod, billingAddress)
: await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand });
customer = await ReconcileBillingLocationAsync(customer, billingAddress);
var subscription = await CreateSubscriptionAsync(user.Id, customer, additionalStorageGb > 0 ? additionalStorageGb : null);
switch (paymentMethod)
{
case { Type: TokenizablePaymentMethodType.PayPal }
when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete:
case { Type: not TokenizablePaymentMethodType.PayPal }
when subscription.Status == StripeConstants.SubscriptionStatus.Active:
{
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
break;
}
}
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
user.GatewaySubscriptionId = subscription.Id;
user.MaxStorageGb = (short)(1 + additionalStorageGb);
user.LicenseKey = CoreHelpers.SecureRandomString(20);
user.RevisionDate = DateTime.UtcNow;
await userService.SaveUserAsync(user);
await pushNotificationService.PushSyncVaultAsync(user.Id);
return new None();
});
private async Task<Customer> CreateCustomerAsync(User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
var subscriberName = user.SubscriberName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Line1 = billingAddress.Line1,
Line2 = billingAddress.Line2,
City = billingAddress.City,
PostalCode = billingAddress.PostalCode,
State = billingAddress.State,
Country = billingAddress.Country
},
Description = user.Name,
Email = user.Email,
Expand = _expand,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = user.SubscriberType(),
Value = subscriberName.Length <= 30
? subscriberName
: subscriberName[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion,
[StripeConstants.MetadataKeys.UserId] = user.Id.ToString()
},
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
var braintreeCustomerId = "";
// ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault
switch (paymentMethod.Type)
{
case TokenizablePaymentMethodType.BankAccount:
{
var setupIntent =
(await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token }))
.FirstOrDefault();
if (setupIntent == null)
{
_logger.LogError("Cannot create customer for user ({UserID}) without a setup intent for their bank account", user.Id);
throw new BillingException();
}
await setupIntentCache.Set(user.Id, setupIntent.Id);
break;
}
case TokenizablePaymentMethodType.Card:
{
customerCreateOptions.PaymentMethod = paymentMethod.Token;
customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token;
break;
}
case TokenizablePaymentMethodType.PayPal:
{
braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.Token);
customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId;
break;
}
default:
{
_logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.Type.ToString());
throw new BillingException();
}
}
try
{
return await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
}
catch
{
await Revert();
throw;
}
async Task Revert()
{
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault
switch (paymentMethod.Type)
{
case TokenizablePaymentMethodType.BankAccount:
{
await setupIntentCache.Remove(user.Id);
break;
}
case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId):
{
await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId);
break;
}
}
}
}
private async Task<Customer> ReconcileBillingLocationAsync(
Customer customer,
BillingAddress billingAddress)
{
/*
* If the customer was previously set up with credit, which does not require a billing location,
* we need to update the customer on the fly before we start the subscription.
*/
if (customer is { Address: { Country: not null and not "", PostalCode: not null and not "" } })
{
return customer;
}
var options = new CustomerUpdateOptions
{
Address = new AddressOptions
{
Line1 = billingAddress.Line1,
Line2 = billingAddress.Line2,
City = billingAddress.City,
PostalCode = billingAddress.PostalCode,
State = billingAddress.State,
Country = billingAddress.Country
},
Expand = _expand,
Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}
};
return await stripeAdapter.CustomerUpdateAsync(customer.Id, options);
}
private async Task<Subscription> CreateSubscriptionAsync(
Guid userId,
Customer customer,
int? storage)
{
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>
{
new ()
{
Price = StripeConstants.Prices.PremiumAnnually,
Quantity = 1
}
};
if (storage is > 0)
{
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
Price = StripeConstants.Prices.StoragePlanPersonal,
Quantity = storage
});
}
var usingPayPal = customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) ?? false;
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,
Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.UserId] = userId.ToString()
},
PaymentBehavior = usingPayPal
? StripeConstants.PaymentBehavior.DefaultIncomplete
: null,
OffSession = true
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal)
{
await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions
{
AutoAdvance = false
});
}
return subscription;
}
}

View File

@@ -0,0 +1,67 @@
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using OneOf.Types;
namespace Bit.Core.Billing.Premium.Commands;
/// <summary>
/// Creates a premium subscription for a self-hosted user.
/// Validates the license and applies premium benefits including storage limits based on the license terms.
/// </summary>
public interface ICreatePremiumSelfHostedSubscriptionCommand
{
/// <summary>
/// Creates a premium self-hosted subscription for the specified user using the provided license.
/// </summary>
/// <param name="user">The user to create the premium subscription for. Must not already be a premium user.</param>
/// <param name="license">The user license containing the premium subscription details and verification data. Must be valid and usable by the specified user.</param>
/// <returns>A billing command result indicating success or failure with appropriate error details.</returns>
Task<BillingCommandResult<None>> Run(User user, UserLicense license);
}
public class CreatePremiumSelfHostedSubscriptionCommand(
ILicensingService licensingService,
IUserService userService,
IPushNotificationService pushNotificationService,
ILogger<CreatePremiumSelfHostedSubscriptionCommand> logger)
: BaseBillingCommand<CreatePremiumSelfHostedSubscriptionCommand>(logger), ICreatePremiumSelfHostedSubscriptionCommand
{
public Task<BillingCommandResult<None>> Run(
User user,
UserLicense license) => HandleAsync<None>(async () =>
{
if (user.Premium)
{
return new BadRequest("Already a premium user.");
}
if (!licensingService.VerifyLicense(license))
{
return new BadRequest("Invalid license.");
}
var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license);
if (!license.CanUse(user, claimsPrincipal, out var exceptionMessage))
{
return new BadRequest(exceptionMessage);
}
await licensingService.WriteUserLicenseAsync(user, license);
user.Premium = true;
user.RevisionDate = DateTime.UtcNow;
user.MaxStorageGb = Core.Constants.SelfHostedMaxStorageGb;
user.LicenseKey = license.LicenseKey;
user.PremiumExpirationDate = license.Expires;
await userService.SaveUserAsync(user);
await pushNotificationService.PushSyncVaultAsync(user.Id);
return new None();
});
}

View File

@@ -26,4 +26,5 @@ public interface ILicensingService
SubscriptionInfo subscriptionInfo);
Task<string?> CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo);
Task WriteUserLicenseAsync(User user, UserLicense license);
}

View File

@@ -389,4 +389,12 @@ public class LicensingService : ILicensingService
var token = tokenHandler.CreateToken(tokenDescriptor);
return tokenHandler.WriteToken(token);
}
public async Task WriteUserLicenseAsync(User user, UserLicense license)
{
var dir = $"{_globalSettings.LicenseDirectory}/user";
Directory.CreateDirectory(dir);
await using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json"));
await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented);
}
}

View File

@@ -304,7 +304,7 @@ public class PremiumUserBillingService(
{
new ()
{
Price = "premium-annually",
Price = StripeConstants.Prices.PremiumAnnually,
Quantity = 1
}
};

View File

@@ -73,4 +73,9 @@ public class NoopLicensingService : ILicensingService
{
return Task.FromResult<string?>(null);
}
public Task WriteUserLicenseAsync(User user, UserLicense license)
{
return Task.CompletedTask;
}
}

View File

@@ -10,6 +10,11 @@ public static class Constants
public const int BypassFiltersEventId = 12482444;
public const int FailedSecretVerificationDelay = 2000;
/// <summary>
/// Self-hosted max storage limit in GB (10 TB).
/// </summary>
public const short SelfHostedMaxStorageGb = 10240;
// File size limits - give 1 MB extra for cushion.
// Note: if request size limits are changed, 'client_max_body_size'
// in nginx/proxy.conf may also need to be updated accordingly.
@@ -166,6 +171,7 @@ public static class FeatureFlagKeys
public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout";
public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover";
public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings";
public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow";
/* Key Management Team */
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";

View File

@@ -906,7 +906,7 @@ public class StripePaymentService : IPaymentService
new()
{
Quantity = 1,
Plan = "premium-annually"
Plan = StripeConstants.Prices.PremiumAnnually
},
new()

View File

@@ -44,8 +44,6 @@ namespace Bit.Core.Services;
public class UserService : UserManager<User>, IUserService
{
private const string PremiumPlanId = "premium-annually";
private readonly IUserRepository _userRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationRepository _organizationRepository;
@@ -930,7 +928,7 @@ public class UserService : UserManager<User>, IUserService
if (_globalSettings.SelfHosted)
{
user.MaxStorageGb = 10240; // 10 TB
user.MaxStorageGb = Constants.SelfHostedMaxStorageGb;
user.LicenseKey = license.LicenseKey;
user.PremiumExpirationDate = license.Expires;
}
@@ -989,7 +987,7 @@ public class UserService : UserManager<User>, IUserService
user.Premium = license.Premium;
user.RevisionDate = DateTime.UtcNow;
user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB
user.MaxStorageGb = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : license.MaxStorageGb;
user.LicenseKey = license.LicenseKey;
user.PremiumExpirationDate = license.Expires;
await SaveUserAsync(user);

View File

@@ -125,7 +125,7 @@ public class SendValidationService : ISendValidationService
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
short limit = _globalSettings.SelfHosted ? (short)10240 : (short)1;
short limit = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1;
storageBytesRemaining = user.StorageBytesRemaining(limit);
}
}

View File

@@ -933,7 +933,7 @@ public class CipherService : ICipherService
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
storageBytesRemaining = user.StorageBytesRemaining(
_globalSettings.SelfHosted ? (short)10240 : (short)1);
_globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1);
}
}
else if (cipher.OrganizationId.HasValue)

View File

@@ -0,0 +1,477 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Test.Common.AutoFixture.Attributes;
using Braintree;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
using Address = Stripe.Address;
using StripeCustomer = Stripe.Customer;
using StripeSubscription = Stripe.Subscription;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class CreatePremiumCloudHostedSubscriptionCommandTests
{
private readonly IBraintreeGateway _braintreeGateway = Substitute.For<IBraintreeGateway>();
private readonly IGlobalSettings _globalSettings = Substitute.For<IGlobalSettings>();
private readonly ISetupIntentCache _setupIntentCache = Substitute.For<ISetupIntentCache>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPushNotificationService _pushNotificationService = Substitute.For<IPushNotificationService>();
private readonly CreatePremiumCloudHostedSubscriptionCommand _command;
public CreatePremiumCloudHostedSubscriptionCommandTests()
{
var baseServiceUri = Substitute.For<IBaseServiceUriSettings>();
baseServiceUri.CloudRegion.Returns("US");
_globalSettings.BaseServiceUri.Returns(baseServiceUri);
_command = new CreatePremiumCloudHostedSubscriptionCommand(
_braintreeGateway,
_globalSettings,
_setupIntentCache,
_stripeAdapter,
_subscriberService,
_userService,
_pushNotificationService,
Substitute.For<ILogger<CreatePremiumCloudHostedSubscriptionCommand>>());
}
[Theory, BitAutoData]
public async Task Run_UserAlreadyPremium_ReturnsBadRequest(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = true;
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Already a premium user.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_NegativeStorageAmount_ReturnsBadRequest(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, -1);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Additional storage must be greater than 0.", badRequest.Response);
}
[Theory, BitAutoData]
public async Task Run_ValidPaymentMethodTypes_BankAccount_Success(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null; // Ensure no existing customer ID
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.BankAccount;
paymentMethod.Token = "bank_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
var mockInvoice = Substitute.For<Invoice>();
var mockSetupIntent = Substitute.For<SetupIntent>();
mockSetupIntent.Id = "seti_123";
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_stripeAdapter.SetupIntentList(Arg.Any<SetupIntentListOptions>()).Returns(Task.FromResult(new List<SetupIntent> { mockSetupIntent }));
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]
public async Task Run_ValidPaymentMethodTypes_Card_Success(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.Card;
paymentMethod.Token = "card_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]
public async Task Run_ValidPaymentMethodTypes_PayPal_Success(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.PayPal;
paymentMethod.Token = "paypal_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_subscriberService.CreateBraintreeCustomer(Arg.Any<User>(), Arg.Any<string>()).Returns("bt_customer_123");
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>());
await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]
public async Task Run_ValidRequestWithAdditionalStorage_Success(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.Card;
paymentMethod.Token = "card_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
const short additionalStorage = 2;
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30);
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, additionalStorage);
// Assert
Assert.True(result.IsT0);
Assert.True(user.Premium);
Assert.Equal((short)(1 + additionalStorage), user.MaxStorageGb);
Assert.NotNull(user.LicenseKey);
Assert.Equal(20, user.LicenseKey.Length);
Assert.NotEqual(default, user.RevisionDate);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
[Theory, BitAutoData]
public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = "existing_customer_123";
paymentMethod.Type = TokenizablePaymentMethodType.Card;
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "existing_customer_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
var mockInvoice = Substitute.For<Invoice>();
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>());
await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any<CustomerCreateOptions>());
}
[Theory, BitAutoData]
public async Task Run_PayPalWithIncompleteSubscription_SetsPremiumTrue(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
user.PremiumExpirationDate = null;
paymentMethod.Type = TokenizablePaymentMethodType.PayPal;
paymentMethod.Token = "paypal_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "incomplete";
mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30);
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.CreateBraintreeCustomer(Arg.Any<User>(), Arg.Any<string>()).Returns("bt_customer_123");
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
Assert.True(user.Premium);
Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate);
}
[Theory, BitAutoData]
public async Task Run_NonPayPalWithActiveSubscription_SetsPremiumTrue(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.Card;
paymentMethod.Token = "card_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active";
mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30);
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
Assert.True(user.Premium);
Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate);
}
[Theory, BitAutoData]
public async Task Run_SubscriptionStatusDoesNotMatchPatterns_DoesNotSetPremium(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
user.PremiumExpirationDate = null;
paymentMethod.Type = TokenizablePaymentMethodType.PayPal;
paymentMethod.Token = "paypal_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "active"; // PayPal + active doesn't match pattern
mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30);
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.CreateBraintreeCustomer(Arg.Any<User>(), Arg.Any<string>()).Returns("bt_customer_123");
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT0);
Assert.False(user.Premium);
Assert.Null(user.PremiumExpirationDate);
}
[Theory, BitAutoData]
public async Task Run_BankAccountWithNoSetupIntentFound_ReturnsUnhandled(
User user,
TokenizedPaymentMethod paymentMethod,
BillingAddress billingAddress)
{
// Arrange
user.Premium = false;
user.GatewayCustomerId = null;
user.Email = "test@example.com";
paymentMethod.Type = TokenizablePaymentMethodType.BankAccount;
paymentMethod.Token = "bank_token_123";
billingAddress.Country = "US";
billingAddress.PostalCode = "12345";
var mockCustomer = Substitute.For<StripeCustomer>();
mockCustomer.Id = "cust_123";
mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" };
mockCustomer.Metadata = new Dictionary<string, string>();
var mockSubscription = Substitute.For<StripeSubscription>();
mockSubscription.Id = "sub_123";
mockSubscription.Status = "incomplete";
mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30);
var mockInvoice = Substitute.For<Invoice>();
_stripeAdapter.CustomerCreateAsync(Arg.Any<CustomerCreateOptions>()).Returns(mockCustomer);
_stripeAdapter.CustomerUpdateAsync(Arg.Any<string>(), Arg.Any<CustomerUpdateOptions>()).Returns(mockCustomer);
_stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(mockSubscription);
_stripeAdapter.InvoiceUpdateAsync(Arg.Any<string>(), Arg.Any<InvoiceUpdateOptions>()).Returns(mockInvoice);
_subscriberService.GetCustomerOrThrow(Arg.Any<User>(), Arg.Any<CustomerGetOptions>()).Returns(mockCustomer);
_stripeAdapter.SetupIntentList(Arg.Any<SetupIntentListOptions>())
.Returns(Task.FromResult(new List<SetupIntent>())); // Empty list - no setup intent found
// Act
var result = await _command.Run(user, paymentMethod, billingAddress, 0);
// Assert
Assert.True(result.IsT3);
var unhandled = result.AsT3;
Assert.Equal("Something went wrong with your request. Please contact support for assistance.", unhandled.Response);
}
}

View File

@@ -0,0 +1,199 @@
using System.Security.Claims;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Premium.Commands;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Billing.Premium.Commands;
public class CreatePremiumSelfHostedSubscriptionCommandTests
{
private readonly ILicensingService _licensingService = Substitute.For<ILicensingService>();
private readonly IUserService _userService = Substitute.For<IUserService>();
private readonly IPushNotificationService _pushNotificationService = Substitute.For<IPushNotificationService>();
private readonly CreatePremiumSelfHostedSubscriptionCommand _command;
public CreatePremiumSelfHostedSubscriptionCommandTests()
{
_command = new CreatePremiumSelfHostedSubscriptionCommand(
_licensingService,
_userService,
_pushNotificationService,
Substitute.For<ILogger<CreatePremiumSelfHostedSubscriptionCommand>>());
}
[Fact]
public async Task Run_UserAlreadyPremium_ReturnsBadRequest()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Premium = true
};
var license = new UserLicense
{
LicenseKey = "test_key",
Expires = DateTime.UtcNow.AddYears(1)
};
// Act
var result = await _command.Run(user, license);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Already a premium user.", badRequest.Response);
}
[Fact]
public async Task Run_InvalidLicense_ReturnsBadRequest()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Premium = false
};
var license = new UserLicense
{
LicenseKey = "invalid_key",
Expires = DateTime.UtcNow.AddYears(1)
};
_licensingService.VerifyLicense(license).Returns(false);
// Act
var result = await _command.Run(user, license);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Equal("Invalid license.", badRequest.Response);
}
[Fact]
public async Task Run_LicenseCannotBeUsed_EmailNotVerified_ReturnsBadRequest()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Premium = false,
Email = "test@example.com",
EmailVerified = false
};
var license = new UserLicense
{
LicenseKey = "test_key",
Expires = DateTime.UtcNow.AddYears(1),
Token = "valid_token"
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[]
{
new Claim("Email", "test@example.com")
}));
_licensingService.VerifyLicense(license).Returns(true);
_licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
// Act
var result = await _command.Run(user, license);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Contains("The user's email is not verified.", badRequest.Response);
}
[Fact]
public async Task Run_LicenseCannotBeUsed_EmailMismatch_ReturnsBadRequest()
{
// Arrange
var user = new User
{
Id = Guid.NewGuid(),
Premium = false,
Email = "user@example.com",
EmailVerified = true
};
var license = new UserLicense
{
LicenseKey = "test_key",
Expires = DateTime.UtcNow.AddYears(1),
Token = "valid_token"
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[]
{
new Claim("Email", "license@example.com")
}));
_licensingService.VerifyLicense(license).Returns(true);
_licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
// Act
var result = await _command.Run(user, license);
// Assert
Assert.True(result.IsT1);
var badRequest = result.AsT1;
Assert.Contains("The user's email does not match the license email.", badRequest.Response);
}
[Fact]
public async Task Run_ValidRequest_Success()
{
// Arrange
var userId = Guid.NewGuid();
var user = new User
{
Id = userId,
Premium = false,
Email = "test@example.com",
EmailVerified = true
};
var license = new UserLicense
{
LicenseKey = "test_key_12345",
Expires = DateTime.UtcNow.AddYears(1),
Token = "valid_token"
};
var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[]
{
new Claim("Email", "test@example.com")
}));
_licensingService.VerifyLicense(license).Returns(true);
_licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
// Act
var result = await _command.Run(user, license);
// Assert
Assert.True(result.IsT0);
// Verify user was updated correctly
Assert.True(user.Premium);
Assert.NotNull(user.LicenseKey);
Assert.Equal(license.LicenseKey, user.LicenseKey);
Assert.NotEqual(default, user.RevisionDate);
// Verify services were called
await _licensingService.Received(1).WriteUserLicenseAsync(user, license);
await _userService.Received(1).SaveUserAsync(user);
await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id);
}
}

View File

@@ -1,8 +1,10 @@
using System.Text.Json;
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Settings;
using Bit.Core.Test.Billing.AutoFixture;
using Bit.Test.Common.AutoFixture;
@@ -16,6 +18,8 @@ public class LicensingServiceTests
{
private static string licenseFilePath(Guid orgId) =>
Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json");
private static string userLicenseFilePath(Guid userId) =>
Path.Combine(UserLicenseDirectory.Value, $"{userId}.json");
private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value);
private static Lazy<string> OrganizationLicenseDirectory => new(() =>
{
@@ -26,6 +30,15 @@ public class LicensingServiceTests
}
return directory;
});
private static Lazy<string> UserLicenseDirectory => new(() =>
{
var directory = Path.Combine(Path.GetTempPath(), "user");
if (!Directory.Exists(directory))
{
Directory.CreateDirectory(directory);
}
return directory;
});
public static SutProvider<LicensingService> GetSutProvider()
{
@@ -57,4 +70,66 @@ public class LicensingServiceTests
Directory.Delete(OrganizationLicenseDirectory.Value, true);
}
}
[Theory, BitAutoData]
public async Task WriteUserLicense_CreatesFileWithCorrectContent(User user, UserLicense license)
{
// Arrange
var sutProvider = GetSutProvider();
var expectedFilePath = userLicenseFilePath(user.Id);
try
{
// Act
await sutProvider.Sut.WriteUserLicenseAsync(user, license);
// Assert
Assert.True(File.Exists(expectedFilePath));
var fileContent = await File.ReadAllTextAsync(expectedFilePath);
var actualLicense = JsonSerializer.Deserialize<UserLicense>(fileContent);
Assert.Equal(license.LicenseKey, actualLicense.LicenseKey);
Assert.Equal(license.Id, actualLicense.Id);
Assert.Equal(license.Expires, actualLicense.Expires);
}
finally
{
// Cleanup
if (Directory.Exists(UserLicenseDirectory.Value))
{
Directory.Delete(UserLicenseDirectory.Value, true);
}
}
}
[Theory, BitAutoData]
public async Task WriteUserLicense_CreatesDirectoryIfNotExists(User user, UserLicense license)
{
// Arrange
var sutProvider = GetSutProvider();
// Ensure directory doesn't exist
if (Directory.Exists(UserLicenseDirectory.Value))
{
Directory.Delete(UserLicenseDirectory.Value, true);
}
try
{
// Act
await sutProvider.Sut.WriteUserLicenseAsync(user, license);
// Assert
Assert.True(Directory.Exists(UserLicenseDirectory.Value));
Assert.True(File.Exists(userLicenseFilePath(user.Id)));
}
finally
{
// Cleanup
if (Directory.Exists(UserLicenseDirectory.Value))
{
Directory.Delete(UserLicenseDirectory.Value, true);
}
}
}
}