From a458db319ea579f412b4760d18bf2c922bee2a86 Mon Sep 17 00:00:00 2001 From: Kyle Denney <4227399+kdenney@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:08:22 -0500 Subject: [PATCH] [PM-25088] - refactor premium purchase endpoint (#6262) * [PM-25088] add feature flag for new premium subscription flow * [PM-25088] refactor premium endpoint * forgot the punctuation change in the test * [PM-25088] - pr feedback * [PM-25088] - pr feedback round two --- .../PaymentMethodTypeValidationAttribute.cs | 13 + .../VNext/AccountBillingVNextController.cs | 17 + .../SelfHostedAccountBillingController.cs | 38 ++ .../MinimalTokenizedPaymentMethodRequest.cs | 25 + .../Payment/TokenizedPaymentMethodRequest.cs | 14 +- .../PremiumCloudHostedSubscriptionRequest.cs | 26 + .../PremiumSelfHostedSubscriptionRequest.cs | 10 + .../Services/OrganizationFactory.cs | 4 +- src/Core/Billing/Constants/StripeConstants.cs | 1 + .../Extensions/ServiceCollectionExtensions.cs | 8 + .../Models/TokenizablePaymentMethodType.cs | 14 + ...tePremiumCloudHostedSubscriptionCommand.cs | 308 +++++++++++ ...atePremiumSelfHostedSubscriptionCommand.cs | 67 +++ .../Billing/Services/ILicensingService.cs | 1 + .../Implementations/LicensingService.cs | 8 + .../PremiumUserBillingService.cs | 2 +- .../NoopLicensingService.cs | 5 + src/Core/Constants.cs | 6 + .../Implementations/StripePaymentService.cs | 2 +- .../Services/Implementations/UserService.cs | 6 +- .../Services/SendValidationService.cs | 2 +- .../Services/Implementations/CipherService.cs | 2 +- ...miumCloudHostedSubscriptionCommandTests.cs | 477 ++++++++++++++++++ ...emiumSelfHostedSubscriptionCommandTests.cs | 199 ++++++++ .../Billing/Services/LicensingServiceTests.cs | 75 +++ 25 files changed, 1309 insertions(+), 21 deletions(-) create mode 100644 src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs create mode 100644 src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Premium/PremiumSelfHostedSubscriptionRequest.cs create mode 100644 src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs create mode 100644 src/Core/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommand.cs create mode 100644 test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs create mode 100644 test/Core.Test/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommandTests.cs diff --git a/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs b/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs new file mode 100644 index 0000000000..227b454f9f --- /dev/null +++ b/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs @@ -0,0 +1,13 @@ +using Bit.Api.Utilities; + +namespace Bit.Api.Billing.Attributes; + +public class PaymentMethodTypeValidationAttribute : StringMatchesAttribute +{ + private static readonly string[] _acceptedValues = ["bankAccount", "card", "payPal"]; + + public PaymentMethodTypeValidationAttribute() : base(_acceptedValues) + { + ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}"; + } +} diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index e3b702e36d..a996290507 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -1,8 +1,11 @@ #nullable enable using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.Premium; +using Bit.Core; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Premium.Commands; using Bit.Core.Entities; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -16,6 +19,7 @@ namespace Bit.Api.Billing.Controllers.VNext; [SelfHosted(NotSelfHostedOnly = true)] public class AccountBillingVNextController( ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, + ICreatePremiumCloudHostedSubscriptionCommand createPremiumCloudHostedSubscriptionCommand, IGetCreditQuery getCreditQuery, IGetPaymentMethodQuery getPaymentMethodQuery, IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController @@ -61,4 +65,17 @@ public class AccountBillingVNextController( var result = await updatePaymentMethodCommand.Run(user, paymentMethod, billingAddress); return Handle(result); } + + [HttpPost("subscription")] + [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [InjectUser] + public async Task CreateSubscriptionAsync( + [BindNever] User user, + [FromBody] PremiumCloudHostedSubscriptionRequest request) + { + var (paymentMethod, billingAddress, additionalStorageGb) = request.ToDomain(); + var result = await createPremiumCloudHostedSubscriptionCommand.Run( + user, paymentMethod, billingAddress, additionalStorageGb); + return Handle(result); + } } diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs new file mode 100644 index 0000000000..544753ad0f --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs @@ -0,0 +1,38 @@ +#nullable enable +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Premium; +using Bit.Api.Utilities; +using Bit.Core; +using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; + +namespace Bit.Api.Billing.Controllers.VNext; + +[Authorize("Application")] +[Route("account/billing/vnext/self-host")] +[SelfHosted(SelfHostedOnly = true)] +public class SelfHostedAccountBillingController( + ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController +{ + [HttpPost("license")] + [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [InjectUser] + public async Task UploadLicenseAsync( + [BindNever] User user, + PremiumSelfHostedSubscriptionRequest request) + { + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, request.License); + if (license == null) + { + throw new BadRequestException("Invalid license."); + } + var result = await createPremiumSelfHostedSubscriptionCommand.Run(user, license); + return Handle(result); + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs new file mode 100644 index 0000000000..3b50d2bf63 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs @@ -0,0 +1,25 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Attributes; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public class MinimalTokenizedPaymentMethodRequest +{ + [Required] + [PaymentMethodTypeValidation] + public required string Type { get; set; } + + [Required] + public required string Token { get; set; } + + public TokenizedPaymentMethod ToDomain() + { + return new TokenizedPaymentMethod + { + Type = TokenizablePaymentMethodTypeExtensions.From(Type), + Token = Token + }; + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs index 663e4e7cd2..f540957a1a 100644 --- a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs @@ -1,6 +1,6 @@ #nullable enable using System.ComponentModel.DataAnnotations; -using Bit.Api.Utilities; +using Bit.Api.Billing.Attributes; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; @@ -8,8 +8,7 @@ namespace Bit.Api.Billing.Models.Requests.Payment; public class TokenizedPaymentMethodRequest { [Required] - [StringMatches("bankAccount", "card", "payPal", - ErrorMessage = "Payment method type must be one of: bankAccount, card, payPal")] + [PaymentMethodTypeValidation] public required string Type { get; set; } [Required] @@ -21,14 +20,7 @@ public class TokenizedPaymentMethodRequest { var paymentMethod = new TokenizedPaymentMethod { - Type = Type switch - { - "bankAccount" => TokenizablePaymentMethodType.BankAccount, - "card" => TokenizablePaymentMethodType.Card, - "payPal" => TokenizablePaymentMethodType.PayPal, - _ => throw new InvalidOperationException( - $"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}") - }, + Type = TokenizablePaymentMethodTypeExtensions.From(Type), Token = Token }; diff --git a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs new file mode 100644 index 0000000000..b958057f5b --- /dev/null +++ b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs @@ -0,0 +1,26 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Premium; + +public class PremiumCloudHostedSubscriptionRequest +{ + [Required] + public required MinimalTokenizedPaymentMethodRequest TokenizedPaymentMethod { get; set; } + + [Required] + public required MinimalBillingAddressRequest BillingAddress { get; set; } + + [Range(0, 99)] + public short AdditionalStorageGb { get; set; } = 0; + + public (TokenizedPaymentMethod, BillingAddress, short) ToDomain() + { + var paymentMethod = TokenizedPaymentMethod.ToDomain(); + var billingAddress = BillingAddress.ToDomain(); + + return (paymentMethod, billingAddress, AdditionalStorageGb); + } +} diff --git a/src/Api/Billing/Models/Requests/Premium/PremiumSelfHostedSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Premium/PremiumSelfHostedSubscriptionRequest.cs new file mode 100644 index 0000000000..261544476e --- /dev/null +++ b/src/Api/Billing/Models/Requests/Premium/PremiumSelfHostedSubscriptionRequest.cs @@ -0,0 +1,10 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests.Premium; + +public class PremiumSelfHostedSubscriptionRequest +{ + [Required] + public required IFormFile License { get; set; } +} diff --git a/src/Core/AdminConsole/Services/OrganizationFactory.cs b/src/Core/AdminConsole/Services/OrganizationFactory.cs index dbc8f0fa21..afb3931ec4 100644 --- a/src/Core/AdminConsole/Services/OrganizationFactory.cs +++ b/src/Core/AdminConsole/Services/OrganizationFactory.cs @@ -23,7 +23,7 @@ public static class OrganizationFactory PlanType = claimsPrincipal.GetValue(OrganizationLicenseConstants.PlanType), Seats = claimsPrincipal.GetValue(OrganizationLicenseConstants.Seats), MaxCollections = claimsPrincipal.GetValue(OrganizationLicenseConstants.MaxCollections), - MaxStorageGb = 10240, + MaxStorageGb = Constants.SelfHostedMaxStorageGb, UsePolicies = claimsPrincipal.GetValue(OrganizationLicenseConstants.UsePolicies), UseSso = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseSso), UseKeyConnector = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseKeyConnector), @@ -75,7 +75,7 @@ public static class OrganizationFactory PlanType = license.PlanType, Seats = license.Seats, MaxCollections = license.MaxCollections, - MaxStorageGb = 10240, + MaxStorageGb = Constants.SelfHostedMaxStorageGb, UsePolicies = license.UsePolicies, UseSso = license.UseSso, UseKeyConnector = license.UseKeyConnector, diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 2be88902c8..131adfedf8 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -79,6 +79,7 @@ public static class StripeConstants public static class Prices { public const string StoragePlanPersonal = "personal-storage-gb-annually"; + public const string PremiumAnnually = "premium-annually"; } public static class ProrationBehavior diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 147e96105a..b4e37f0151 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -5,6 +5,7 @@ using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Payment; +using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -30,6 +31,7 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddPaymentOperations(); services.AddOrganizationLicenseCommandsQueries(); + services.AddPremiumCommands(); services.AddTransient(); } @@ -39,4 +41,10 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); } + + private static void AddPremiumCommands(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + } } diff --git a/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs index d27a924360..c198ec8230 100644 --- a/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs +++ b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs @@ -6,3 +6,17 @@ public enum TokenizablePaymentMethodType Card, PayPal } + +public static class TokenizablePaymentMethodTypeExtensions +{ + public static TokenizablePaymentMethodType From(string type) + { + return type switch + { + "bankAccount" => TokenizablePaymentMethodType.BankAccount, + "card" => TokenizablePaymentMethodType.Card, + "payPal" => TokenizablePaymentMethodType.PayPal, + _ => throw new InvalidOperationException($"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}") + }; + } +} diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs new file mode 100644 index 0000000000..8a73f31880 --- /dev/null +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -0,0 +1,308 @@ +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Braintree; +using Microsoft.Extensions.Logging; +using OneOf.Types; +using Stripe; +using Customer = Stripe.Customer; +using Subscription = Stripe.Subscription; + +namespace Bit.Core.Billing.Premium.Commands; + +using static Utilities; + +/// +/// Creates a premium subscription for a cloud-hosted user with Stripe payment processing. +/// Handles customer creation, payment method setup, and subscription creation. +/// +public interface ICreatePremiumCloudHostedSubscriptionCommand +{ + /// + /// Creates a premium cloud-hosted subscription for the specified user. + /// + /// The user to create the premium subscription for. Must not already be a premium user. + /// The tokenized payment method containing the payment type and token for billing. + /// The billing address information required for tax calculation and customer creation. + /// Additional storage in GB beyond the base 1GB included with premium (must be >= 0). + /// A billing command result indicating success or failure with appropriate error details. + Task> Run( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress, + short additionalStorageGb); +} + +public class CreatePremiumCloudHostedSubscriptionCommand( + IBraintreeGateway braintreeGateway, + IGlobalSettings globalSettings, + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService, + IUserService userService, + IPushNotificationService pushNotificationService, + ILogger logger) + : BaseBillingCommand(logger), ICreatePremiumCloudHostedSubscriptionCommand +{ + private static readonly List _expand = ["tax"]; + private readonly ILogger _logger = logger; + + public Task> Run( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress, + short additionalStorageGb) => HandleAsync(async () => + { + if (user.Premium) + { + return new BadRequest("Already a premium user."); + } + + if (additionalStorageGb < 0) + { + return new BadRequest("Additional storage must be greater than 0."); + } + + var customer = string.IsNullOrEmpty(user.GatewayCustomerId) + ? await CreateCustomerAsync(user, paymentMethod, billingAddress) + : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + + customer = await ReconcileBillingLocationAsync(customer, billingAddress); + + var subscription = await CreateSubscriptionAsync(user.Id, customer, additionalStorageGb > 0 ? additionalStorageGb : null); + + switch (paymentMethod) + { + case { Type: TokenizablePaymentMethodType.PayPal } + when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete: + case { Type: not TokenizablePaymentMethodType.PayPal } + when subscription.Status == StripeConstants.SubscriptionStatus.Active: + { + user.Premium = true; + user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + break; + } + } + + user.Gateway = GatewayType.Stripe; + user.GatewayCustomerId = customer.Id; + user.GatewaySubscriptionId = subscription.Id; + user.MaxStorageGb = (short)(1 + additionalStorageGb); + user.LicenseKey = CoreHelpers.SecureRandomString(20); + user.RevisionDate = DateTime.UtcNow; + + await userService.SaveUserAsync(user); + await pushNotificationService.PushSyncVaultAsync(user.Id); + + return new None(); + }); + + private async Task CreateCustomerAsync(User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + var subscriberName = user.SubscriberName(); + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + PostalCode = billingAddress.PostalCode, + State = billingAddress.State, + Country = billingAddress.Country + }, + Description = user.Name, + Email = user.Email, + Expand = _expand, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = user.SubscriberType(), + Value = subscriberName.Length <= 30 + ? subscriberName + : subscriberName[..30] + } + ] + }, + Metadata = new Dictionary + { + [StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [StripeConstants.MetadataKeys.UserId] = user.Id.ToString() + }, + Tax = new CustomerTaxOptions + { + ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + } + }; + + var braintreeCustomerId = ""; + + // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault + switch (paymentMethod.Type) + { + case TokenizablePaymentMethodType.BankAccount: + { + var setupIntent = + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) + .FirstOrDefault(); + + if (setupIntent == null) + { + _logger.LogError("Cannot create customer for user ({UserID}) without a setup intent for their bank account", user.Id); + throw new BillingException(); + } + + await setupIntentCache.Set(user.Id, setupIntent.Id); + break; + } + case TokenizablePaymentMethodType.Card: + { + customerCreateOptions.PaymentMethod = paymentMethod.Token; + customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token; + break; + } + case TokenizablePaymentMethodType.PayPal: + { + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.Token); + customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; + break; + } + default: + { + _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.Type.ToString()); + throw new BillingException(); + } + } + + try + { + return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + } + catch + { + await Revert(); + throw; + } + + async Task Revert() + { + // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault + switch (paymentMethod.Type) + { + case TokenizablePaymentMethodType.BankAccount: + { + await setupIntentCache.Remove(user.Id); + break; + } + case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + { + await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); + break; + } + } + } + } + + private async Task ReconcileBillingLocationAsync( + Customer customer, + BillingAddress billingAddress) + { + /* + * If the customer was previously set up with credit, which does not require a billing location, + * we need to update the customer on the fly before we start the subscription. + */ + if (customer is { Address: { Country: not null and not "", PostalCode: not null and not "" } }) + { + return customer; + } + + var options = new CustomerUpdateOptions + { + Address = new AddressOptions + { + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + PostalCode = billingAddress.PostalCode, + State = billingAddress.State, + Country = billingAddress.Country + }, + Expand = _expand, + Tax = new CustomerTaxOptions + { + ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + } + }; + return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + } + + private async Task CreateSubscriptionAsync( + Guid userId, + Customer customer, + int? storage) + { + var subscriptionItemOptionsList = new List + { + new () + { + Price = StripeConstants.Prices.PremiumAnnually, + Quantity = 1 + } + }; + + if (storage is > 0) + { + subscriptionItemOptionsList.Add(new SubscriptionItemOptions + { + Price = StripeConstants.Prices.StoragePlanPersonal, + Quantity = storage + }); + } + + var usingPayPal = customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) ?? false; + + var subscriptionCreateOptions = new SubscriptionCreateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }, + CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + Customer = customer.Id, + Items = subscriptionItemOptionsList, + Metadata = new Dictionary + { + [StripeConstants.MetadataKeys.UserId] = userId.ToString() + }, + PaymentBehavior = usingPayPal + ? StripeConstants.PaymentBehavior.DefaultIncomplete + : null, + OffSession = true + }; + + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + if (usingPayPal) + { + await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + { + AutoAdvance = false + }); + } + + return subscription; + } +} diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommand.cs new file mode 100644 index 0000000000..7546149ab6 --- /dev/null +++ b/src/Core/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommand.cs @@ -0,0 +1,67 @@ +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Platform.Push; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using OneOf.Types; + +namespace Bit.Core.Billing.Premium.Commands; + +/// +/// Creates a premium subscription for a self-hosted user. +/// Validates the license and applies premium benefits including storage limits based on the license terms. +/// +public interface ICreatePremiumSelfHostedSubscriptionCommand +{ + /// + /// Creates a premium self-hosted subscription for the specified user using the provided license. + /// + /// The user to create the premium subscription for. Must not already be a premium user. + /// The user license containing the premium subscription details and verification data. Must be valid and usable by the specified user. + /// A billing command result indicating success or failure with appropriate error details. + Task> Run(User user, UserLicense license); +} + +public class CreatePremiumSelfHostedSubscriptionCommand( + ILicensingService licensingService, + IUserService userService, + IPushNotificationService pushNotificationService, + ILogger logger) + : BaseBillingCommand(logger), ICreatePremiumSelfHostedSubscriptionCommand +{ + public Task> Run( + User user, + UserLicense license) => HandleAsync(async () => + { + if (user.Premium) + { + return new BadRequest("Already a premium user."); + } + + if (!licensingService.VerifyLicense(license)) + { + return new BadRequest("Invalid license."); + } + + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license); + if (!license.CanUse(user, claimsPrincipal, out var exceptionMessage)) + { + return new BadRequest(exceptionMessage); + } + + await licensingService.WriteUserLicenseAsync(user, license); + + user.Premium = true; + user.RevisionDate = DateTime.UtcNow; + user.MaxStorageGb = Core.Constants.SelfHostedMaxStorageGb; + user.LicenseKey = license.LicenseKey; + user.PremiumExpirationDate = license.Expires; + + await userService.SaveUserAsync(user); + await pushNotificationService.PushSyncVaultAsync(user.Id); + + return new None(); + }); +} diff --git a/src/Core/Billing/Services/ILicensingService.cs b/src/Core/Billing/Services/ILicensingService.cs index b6ada998a7..cd9847ea39 100644 --- a/src/Core/Billing/Services/ILicensingService.cs +++ b/src/Core/Billing/Services/ILicensingService.cs @@ -26,4 +26,5 @@ public interface ILicensingService SubscriptionInfo subscriptionInfo); Task CreateUserTokenAsync(User user, SubscriptionInfo subscriptionInfo); + Task WriteUserLicenseAsync(User user, UserLicense license); } diff --git a/src/Core/Billing/Services/Implementations/LicensingService.cs b/src/Core/Billing/Services/Implementations/LicensingService.cs index 81a52158ce..6f0cdec8f5 100644 --- a/src/Core/Billing/Services/Implementations/LicensingService.cs +++ b/src/Core/Billing/Services/Implementations/LicensingService.cs @@ -389,4 +389,12 @@ public class LicensingService : ILicensingService var token = tokenHandler.CreateToken(tokenDescriptor); return tokenHandler.WriteToken(token); } + + public async Task WriteUserLicenseAsync(User user, UserLicense license) + { + var dir = $"{_globalSettings.LicenseDirectory}/user"; + Directory.CreateDirectory(dir); + await using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + } } diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 986991ba0a..9db18278b6 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -304,7 +304,7 @@ public class PremiumUserBillingService( { new () { - Price = "premium-annually", + Price = StripeConstants.Prices.PremiumAnnually, Quantity = 1 } }; diff --git a/src/Core/Billing/Services/NoopImplementations/NoopLicensingService.cs b/src/Core/Billing/Services/NoopImplementations/NoopLicensingService.cs index a54ba3546a..b27e21a7c9 100644 --- a/src/Core/Billing/Services/NoopImplementations/NoopLicensingService.cs +++ b/src/Core/Billing/Services/NoopImplementations/NoopLicensingService.cs @@ -73,4 +73,9 @@ public class NoopLicensingService : ILicensingService { return Task.FromResult(null); } + + public Task WriteUserLicenseAsync(User user, UserLicense license) + { + return Task.CompletedTask; + } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 69003ee253..cba060427c 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -10,6 +10,11 @@ public static class Constants public const int BypassFiltersEventId = 12482444; public const int FailedSecretVerificationDelay = 2000; + /// + /// Self-hosted max storage limit in GB (10 TB). + /// + public const short SelfHostedMaxStorageGb = 10240; + // File size limits - give 1 MB extra for cushion. // Note: if request size limits are changed, 'client_max_body_size' // in nginx/proxy.conf may also need to be updated accordingly. @@ -166,6 +171,7 @@ public static class FeatureFlagKeys public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout"; public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; + public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index ec45944bd2..5b68906d8a 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -906,7 +906,7 @@ public class StripePaymentService : IPaymentService new() { Quantity = 1, - Plan = "premium-annually" + Plan = StripeConstants.Prices.PremiumAnnually }, new() diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 16e298d177..386cb8c3d2 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -44,8 +44,6 @@ namespace Bit.Core.Services; public class UserService : UserManager, IUserService { - private const string PremiumPlanId = "premium-annually"; - private readonly IUserRepository _userRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationRepository _organizationRepository; @@ -930,7 +928,7 @@ public class UserService : UserManager, IUserService if (_globalSettings.SelfHosted) { - user.MaxStorageGb = 10240; // 10 TB + user.MaxStorageGb = Constants.SelfHostedMaxStorageGb; user.LicenseKey = license.LicenseKey; user.PremiumExpirationDate = license.Expires; } @@ -989,7 +987,7 @@ public class UserService : UserManager, IUserService user.Premium = license.Premium; user.RevisionDate = DateTime.UtcNow; - user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB + user.MaxStorageGb = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : license.MaxStorageGb; user.LicenseKey = license.LicenseKey; user.PremiumExpirationDate = license.Expires; await SaveUserAsync(user); diff --git a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs index c6dd3b1dc9..c545c8b35f 100644 --- a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs +++ b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs @@ -125,7 +125,7 @@ public class SendValidationService : ISendValidationService { // Users that get access to file storage/premium from their organization get the default // 1 GB max storage. - short limit = _globalSettings.SelfHosted ? (short)10240 : (short)1; + short limit = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1; storageBytesRemaining = user.StorageBytesRemaining(limit); } } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index 2a4cc6c137..e0b121fdd3 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -933,7 +933,7 @@ public class CipherService : ICipherService // Users that get access to file storage/premium from their organization get the default // 1 GB max storage. storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? (short)10240 : (short)1); + _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1); } } else if (cipher.OrganizationId.HasValue) diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs new file mode 100644 index 0000000000..e808fb10b0 --- /dev/null +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -0,0 +1,477 @@ +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Platform.Push; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Test.Common.AutoFixture.Attributes; +using Braintree; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; +using Address = Stripe.Address; +using StripeCustomer = Stripe.Customer; +using StripeSubscription = Stripe.Subscription; + +namespace Bit.Core.Test.Billing.Premium.Commands; + +public class CreatePremiumCloudHostedSubscriptionCommandTests +{ + private readonly IBraintreeGateway _braintreeGateway = Substitute.For(); + private readonly IGlobalSettings _globalSettings = Substitute.For(); + private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly IUserService _userService = Substitute.For(); + private readonly IPushNotificationService _pushNotificationService = Substitute.For(); + private readonly CreatePremiumCloudHostedSubscriptionCommand _command; + + public CreatePremiumCloudHostedSubscriptionCommandTests() + { + var baseServiceUri = Substitute.For(); + baseServiceUri.CloudRegion.Returns("US"); + _globalSettings.BaseServiceUri.Returns(baseServiceUri); + + _command = new CreatePremiumCloudHostedSubscriptionCommand( + _braintreeGateway, + _globalSettings, + _setupIntentCache, + _stripeAdapter, + _subscriberService, + _userService, + _pushNotificationService, + Substitute.For>()); + } + + [Theory, BitAutoData] + public async Task Run_UserAlreadyPremium_ReturnsBadRequest( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Already a premium user.", badRequest.Response); + } + + [Theory, BitAutoData] + public async Task Run_NegativeStorageAmount_ReturnsBadRequest( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, -1); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Additional storage must be greater than 0.", badRequest.Response); + } + + [Theory, BitAutoData] + public async Task Run_ValidPaymentMethodTypes_BankAccount_Success( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; // Ensure no existing customer ID + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.BankAccount; + paymentMethod.Token = "bank_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + + var mockInvoice = Substitute.For(); + + var mockSetupIntent = Substitute.For(); + mockSetupIntent.Id = "seti_123"; + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.SetupIntentList(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); + } + + [Theory, BitAutoData] + public async Task Run_ValidPaymentMethodTypes_Card_Success( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); + } + + [Theory, BitAutoData] + public async Task Run_ValidPaymentMethodTypes_PayPal_Success( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.PayPal; + paymentMethod.Token = "paypal_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); + } + + [Theory, BitAutoData] + public async Task Run_ValidRequestWithAdditionalStorage_Success( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + const short additionalStorage = 2; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, additionalStorage); + + // Assert + Assert.True(result.IsT0); + Assert.True(user.Premium); + Assert.Equal((short)(1 + additionalStorage), user.MaxStorageGb); + Assert.NotNull(user.LicenseKey); + Assert.Equal(20, user.LicenseKey.Length); + Assert.NotEqual(default, user.RevisionDate); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); + } + + [Theory, BitAutoData] + public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = "existing_customer_123"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "existing_customer_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + + var mockInvoice = Substitute.For(); + + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_PayPalWithIncompleteSubscription_SetsPremiumTrue( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + user.PremiumExpirationDate = null; + paymentMethod.Type = TokenizablePaymentMethodType.PayPal; + paymentMethod.Token = "paypal_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "incomplete"; + mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + Assert.True(user.Premium); + Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate); + } + + [Theory, BitAutoData] + public async Task Run_NonPayPalWithActiveSubscription_SetsPremiumTrue( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + Assert.True(user.Premium); + Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate); + } + + [Theory, BitAutoData] + public async Task Run_SubscriptionStatusDoesNotMatchPatterns_DoesNotSetPremium( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + user.PremiumExpirationDate = null; + paymentMethod.Type = TokenizablePaymentMethodType.PayPal; + paymentMethod.Token = "paypal_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; // PayPal + active doesn't match pattern + mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + Assert.False(user.Premium); + Assert.Null(user.PremiumExpirationDate); + } + + [Theory, BitAutoData] + public async Task Run_BankAccountWithNoSetupIntentFound_ReturnsUnhandled( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.BankAccount; + paymentMethod.Token = "bank_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "incomplete"; + mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var mockInvoice = Substitute.For(); + + _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + + _stripeAdapter.SetupIntentList(Arg.Any()) + .Returns(Task.FromResult(new List())); // Empty list - no setup intent found + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.Equal("Something went wrong with your request. Please contact support for assistance.", unhandled.Response); + } +} diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommandTests.cs new file mode 100644 index 0000000000..6dfd620e45 --- /dev/null +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumSelfHostedSubscriptionCommandTests.cs @@ -0,0 +1,199 @@ +using System.Security.Claims; +using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Platform.Push; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Billing.Premium.Commands; + +public class CreatePremiumSelfHostedSubscriptionCommandTests +{ + private readonly ILicensingService _licensingService = Substitute.For(); + private readonly IUserService _userService = Substitute.For(); + private readonly IPushNotificationService _pushNotificationService = Substitute.For(); + private readonly CreatePremiumSelfHostedSubscriptionCommand _command; + + public CreatePremiumSelfHostedSubscriptionCommandTests() + { + _command = new CreatePremiumSelfHostedSubscriptionCommand( + _licensingService, + _userService, + _pushNotificationService, + Substitute.For>()); + } + + [Fact] + public async Task Run_UserAlreadyPremium_ReturnsBadRequest() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Premium = true + }; + + var license = new UserLicense + { + LicenseKey = "test_key", + Expires = DateTime.UtcNow.AddYears(1) + }; + + // Act + var result = await _command.Run(user, license); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Already a premium user.", badRequest.Response); + } + + [Fact] + public async Task Run_InvalidLicense_ReturnsBadRequest() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Premium = false + }; + + var license = new UserLicense + { + LicenseKey = "invalid_key", + Expires = DateTime.UtcNow.AddYears(1) + }; + + _licensingService.VerifyLicense(license).Returns(false); + + // Act + var result = await _command.Run(user, license); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("Invalid license.", badRequest.Response); + } + + [Fact] + public async Task Run_LicenseCannotBeUsed_EmailNotVerified_ReturnsBadRequest() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Premium = false, + Email = "test@example.com", + EmailVerified = false + }; + + var license = new UserLicense + { + LicenseKey = "test_key", + Expires = DateTime.UtcNow.AddYears(1), + Token = "valid_token" + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[] + { + new Claim("Email", "test@example.com") + })); + + _licensingService.VerifyLicense(license).Returns(true); + _licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal); + + // Act + var result = await _command.Run(user, license); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Contains("The user's email is not verified.", badRequest.Response); + } + + [Fact] + public async Task Run_LicenseCannotBeUsed_EmailMismatch_ReturnsBadRequest() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Premium = false, + Email = "user@example.com", + EmailVerified = true + }; + + var license = new UserLicense + { + LicenseKey = "test_key", + Expires = DateTime.UtcNow.AddYears(1), + Token = "valid_token" + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[] + { + new Claim("Email", "license@example.com") + })); + + _licensingService.VerifyLicense(license).Returns(true); + _licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal); + + // Act + var result = await _command.Run(user, license); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Contains("The user's email does not match the license email.", badRequest.Response); + } + + [Fact] + public async Task Run_ValidRequest_Success() + { + // Arrange + var userId = Guid.NewGuid(); + var user = new User + { + Id = userId, + Premium = false, + Email = "test@example.com", + EmailVerified = true + }; + + var license = new UserLicense + { + LicenseKey = "test_key_12345", + Expires = DateTime.UtcNow.AddYears(1), + Token = "valid_token" + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity(new[] + { + new Claim("Email", "test@example.com") + })); + + _licensingService.VerifyLicense(license).Returns(true); + _licensingService.GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal); + + // Act + var result = await _command.Run(user, license); + + // Assert + Assert.True(result.IsT0); + + // Verify user was updated correctly + Assert.True(user.Premium); + Assert.NotNull(user.LicenseKey); + Assert.Equal(license.LicenseKey, user.LicenseKey); + Assert.NotEqual(default, user.RevisionDate); + + // Verify services were called + await _licensingService.Received(1).WriteUserLicenseAsync(user, license); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); + } +} diff --git a/test/Core.Test/Billing/Services/LicensingServiceTests.cs b/test/Core.Test/Billing/Services/LicensingServiceTests.cs index f33bda2164..cc160dec71 100644 --- a/test/Core.Test/Billing/Services/LicensingServiceTests.cs +++ b/test/Core.Test/Billing/Services/LicensingServiceTests.cs @@ -1,8 +1,10 @@ using System.Text.Json; using AutoFixture; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; +using Bit.Core.Entities; using Bit.Core.Settings; using Bit.Core.Test.Billing.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -16,6 +18,8 @@ public class LicensingServiceTests { private static string licenseFilePath(Guid orgId) => Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json"); + private static string userLicenseFilePath(Guid userId) => + Path.Combine(UserLicenseDirectory.Value, $"{userId}.json"); private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value); private static Lazy OrganizationLicenseDirectory => new(() => { @@ -26,6 +30,15 @@ public class LicensingServiceTests } return directory; }); + private static Lazy UserLicenseDirectory => new(() => + { + var directory = Path.Combine(Path.GetTempPath(), "user"); + if (!Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } + return directory; + }); public static SutProvider GetSutProvider() { @@ -57,4 +70,66 @@ public class LicensingServiceTests Directory.Delete(OrganizationLicenseDirectory.Value, true); } } + + [Theory, BitAutoData] + public async Task WriteUserLicense_CreatesFileWithCorrectContent(User user, UserLicense license) + { + // Arrange + var sutProvider = GetSutProvider(); + var expectedFilePath = userLicenseFilePath(user.Id); + + try + { + // Act + await sutProvider.Sut.WriteUserLicenseAsync(user, license); + + // Assert + Assert.True(File.Exists(expectedFilePath)); + var fileContent = await File.ReadAllTextAsync(expectedFilePath); + var actualLicense = JsonSerializer.Deserialize(fileContent); + + Assert.Equal(license.LicenseKey, actualLicense.LicenseKey); + Assert.Equal(license.Id, actualLicense.Id); + Assert.Equal(license.Expires, actualLicense.Expires); + } + finally + { + // Cleanup + if (Directory.Exists(UserLicenseDirectory.Value)) + { + Directory.Delete(UserLicenseDirectory.Value, true); + } + } + } + + [Theory, BitAutoData] + public async Task WriteUserLicense_CreatesDirectoryIfNotExists(User user, UserLicense license) + { + // Arrange + var sutProvider = GetSutProvider(); + + // Ensure directory doesn't exist + if (Directory.Exists(UserLicenseDirectory.Value)) + { + Directory.Delete(UserLicenseDirectory.Value, true); + } + + try + { + // Act + await sutProvider.Sut.WriteUserLicenseAsync(user, license); + + // Assert + Assert.True(Directory.Exists(UserLicenseDirectory.Value)); + Assert.True(File.Exists(userLicenseFilePath(user.Id))); + } + finally + { + // Cleanup + if (Directory.Exists(UserLicenseDirectory.Value)) + { + Directory.Delete(UserLicenseDirectory.Value, true); + } + } + } }