diff --git a/src/Core/Billing/Payment/Models/PaymentMethod.cs b/src/Core/Billing/Payment/Models/PaymentMethod.cs index a6835f9a32..b0733da414 100644 --- a/src/Core/Billing/Payment/Models/PaymentMethod.cs +++ b/src/Core/Billing/Payment/Models/PaymentMethod.cs @@ -11,7 +11,9 @@ public class PaymentMethod(OneOf new(tokenized); public static implicit operator PaymentMethod(NonTokenizedPaymentMethod nonTokenized) => new(nonTokenized); public bool IsTokenized => IsT0; + public TokenizedPaymentMethod AsTokenized => AsT0; public bool IsNonTokenized => IsT1; + public NonTokenizedPaymentMethod AsNonTokenized => AsT1; } internal class PaymentMethodJsonConverter : JsonConverter diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 3b2ac5343f..1f752a007b 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -2,7 +2,9 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; @@ -21,6 +23,7 @@ using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Premium.Commands; +using static StripeConstants; using static Utilities; /// @@ -32,7 +35,7 @@ public interface ICreatePremiumCloudHostedSubscriptionCommand /// /// Creates a premium cloud-hosted subscription for the specified user. /// - /// The user to create the premium subscription for. Must not already be a premium user. + /// The user to create the premium subscription for. Must not yet be a premium user. /// The tokenized payment method containing the payment type and token for billing. /// The billing address information required for tax calculation and customer creation. /// Additional storage in GB beyond the base 1GB included with premium (must be >= 0). @@ -53,7 +56,9 @@ public class CreatePremiumCloudHostedSubscriptionCommand( IUserService userService, IPushNotificationService pushNotificationService, ILogger logger, - IPricingClient pricingClient) + IPricingClient pricingClient, + IHasPaymentMethodQuery hasPaymentMethodQuery, + IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingCommand(logger), ICreatePremiumCloudHostedSubscriptionCommand { private static readonly List _expand = ["tax"]; @@ -75,10 +80,30 @@ public class CreatePremiumCloudHostedSubscriptionCommand( return new BadRequest("Additional storage must be greater than 0."); } - // Note: A customer will already exist if the customer has purchased account credits. - var customer = string.IsNullOrEmpty(user.GatewayCustomerId) - ? await CreateCustomerAsync(user, paymentMethod, billingAddress) - : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + Customer? customer; + + /* + * For a new customer purchasing a new subscription, we attach the payment method while creating the customer. + */ + if (string.IsNullOrEmpty(user.GatewayCustomerId)) + { + customer = await CreateCustomerAsync(user, paymentMethod, billingAddress); + } + /* + * An existing customer without a payment method starting a new subscription indicates a user who previously + * purchased account credit but chose to use a tokenizable payment method to pay for the subscription. In this case, + * we need to add the payment method to their customer first. If the incoming payment method is account credit, + * we can just go straight to fetching the customer since there's no payment method to apply. + */ + else if (paymentMethod.IsTokenized && !await hasPaymentMethodQuery.Run(user)) + { + await updatePaymentMethodCommand.Run(user, paymentMethod.AsTokenized, billingAddress); + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } + else + { + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } customer = await ReconcileBillingLocationAsync(customer, billingAddress); @@ -91,9 +116,9 @@ public class CreatePremiumCloudHostedSubscriptionCommand( switch (tokenized) { case { Type: TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete: + when subscription.Status == SubscriptionStatus.Incomplete: case { Type: not TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Active: + when subscription.Status == SubscriptionStatus.Active: { user.Premium = true; user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); @@ -101,13 +126,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } } }, - nonTokenized => + _ => { - if (subscription.Status == StripeConstants.SubscriptionStatus.Active) + if (subscription.Status != SubscriptionStatus.Active) { - user.Premium = true; - user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + return; } + + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); }); user.Gateway = GatewayType.Stripe; @@ -163,25 +190,25 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, - [StripeConstants.MetadataKeys.UserId] = user.Id.ToString() + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [MetadataKeys.UserId] = user.Id.ToString() }, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; var braintreeCustomerId = ""; // We have checked that the payment method is tokenized, so we can safely cast it. - // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault - switch (paymentMethod.AsT0.Type) + var tokenizedPaymentMethod = paymentMethod.AsTokenized; + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.AsT0.Token })) + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -195,19 +222,19 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } case TokenizablePaymentMethodType.Card: { - customerCreateOptions.PaymentMethod = paymentMethod.AsT0.Token; - customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.AsT0.Token; + customerCreateOptions.PaymentMethod = tokenizedPaymentMethod.Token; + customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = tokenizedPaymentMethod.Token; break; } case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.AsT0.Token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, tokenizedPaymentMethod.Token); customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } default: { - _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.AsT0.Type.ToString()); + _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, tokenizedPaymentMethod.Type.ToString()); throw new BillingException(); } } @@ -225,21 +252,18 @@ public class CreatePremiumCloudHostedSubscriptionCommand( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - if (paymentMethod.IsTokenized) + switch (tokenizedPaymentMethod.Type) { - switch (paymentMethod.AsT0.Type) - { - case TokenizablePaymentMethodType.BankAccount: - { - await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id); - break; - } - case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): - { - await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); - break; - } - } + case TokenizablePaymentMethodType.BankAccount: + { + await setupIntentCache.RemoveSetupIntentForSubscriber(user.Id); + break; + } + case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + { + await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); + break; + } } } } @@ -271,7 +295,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( Expand = _expand, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); @@ -310,15 +334,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { Enabled = true }, - CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + CollectionMethod = CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.UserId] = userId.ToString() + [MetadataKeys.UserId] = userId.ToString() }, PaymentBehavior = usingPayPal - ? StripeConstants.PaymentBehavior.DefaultIncomplete + ? PaymentBehavior.DefaultIncomplete : null, OffSession = true }; diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index c0618f78ed..493246c578 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -2,7 +2,9 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; @@ -34,6 +36,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests private readonly IUserService _userService = Substitute.For(); private readonly IPushNotificationService _pushNotificationService = Substitute.For(); private readonly IPricingClient _pricingClient = Substitute.For(); + private readonly IHasPaymentMethodQuery _hasPaymentMethodQuery = Substitute.For(); + private readonly IUpdatePaymentMethodCommand _updatePaymentMethodCommand = Substitute.For(); private readonly CreatePremiumCloudHostedSubscriptionCommand _command; public CreatePremiumCloudHostedSubscriptionCommandTests() @@ -62,7 +66,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests _userService, _pushNotificationService, Substitute.For>(), - _pricingClient); + _pricingClient, + _hasPaymentMethodQuery, + _updatePaymentMethodCommand); } [Theory, BitAutoData] @@ -314,7 +320,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests } [Theory, BitAutoData] - public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer( + public async Task Run_UserHasExistingGatewayCustomerIdAndPaymentMethod_UsesExistingCustomer( User user, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) @@ -347,6 +353,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); + // Mock that the user has a payment method (this is the key difference from the credit purchase case) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(true); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); @@ -358,6 +366,75 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_UserPreviouslyPurchasedCreditWithoutPaymentMethod_UpdatesPaymentMethodAndCreatesSubscription( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = "existing_customer_123"; // Customer exists from previous credit purchase + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "existing_customer_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + var mockInvoice = Substitute.For(); + MaskedPaymentMethod mockMaskedPaymentMethod = new MaskedCard + { + Brand = "visa", + Last4 = "1234", + Expiration = "12/2025" + }; + + // Mock that the user does NOT have a payment method (simulating credit purchase scenario) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(false); + _updatePaymentMethodCommand.Run(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(mockMaskedPaymentMethod); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + // Verify that update payment method was called (new behavior for credit purchase case) + await _updatePaymentMethodCommand.Received(1).Run(user, paymentMethod, billingAddress); + // Verify GetCustomerOrThrow was called after updating payment method + await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); + // Verify no new customer was created + await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + // Verify subscription was created + await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + // Verify user was updated correctly + Assert.True(user.Premium); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } [Theory, BitAutoData]