diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj index 116efdb68c..25327b17b7 100644 --- a/src/Billing/Billing.csproj +++ b/src/Billing/Billing.csproj @@ -6,6 +6,7 @@ + diff --git a/src/Billing/Services/IStripeFacade.cs b/src/Billing/Services/IStripeFacade.cs index 6886250a33..37ba51cc61 100644 --- a/src/Billing/Services/IStripeFacade.cs +++ b/src/Billing/Services/IStripeFacade.cs @@ -2,6 +2,7 @@ #nullable disable using Stripe; +using Stripe.TestHelpers; namespace Bit.Billing.Services; @@ -98,4 +99,10 @@ public interface IStripeFacade string subscriptionId, RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + + Task GetTestClock( + string testClockId, + TestClockGetOptions testClockGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); } diff --git a/src/Billing/Services/Implementations/StripeFacade.cs b/src/Billing/Services/Implementations/StripeFacade.cs index 70144d8cd3..726a3e977c 100644 --- a/src/Billing/Services/Implementations/StripeFacade.cs +++ b/src/Billing/Services/Implementations/StripeFacade.cs @@ -2,6 +2,8 @@ #nullable disable using Stripe; +using Stripe.TestHelpers; +using CustomerService = Stripe.CustomerService; namespace Bit.Billing.Services.Implementations; @@ -14,6 +16,7 @@ public class StripeFacade : IStripeFacade private readonly PaymentMethodService _paymentMethodService = new(); private readonly SubscriptionService _subscriptionService = new(); private readonly DiscountService _discountService = new(); + private readonly TestClockService _testClockService = new(); public async Task GetCharge( string chargeId, @@ -119,4 +122,11 @@ public class StripeFacade : IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default) => await _discountService.DeleteSubscriptionDiscountAsync(subscriptionId, requestOptions, cancellationToken); + + public Task GetTestClock( + string testClockId, + TestClockGetOptions testClockGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + _testClockService.GetAsync(testClockId, testClockGetOptions, requestOptions, cancellationToken); } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index fe5021c827..bbc17aa3b2 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,6 +1,9 @@ using Bit.Billing.Constants; using Bit.Billing.Jobs; +using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Platform.Push; @@ -8,6 +11,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Quartz; using Stripe; +using Stripe.TestHelpers; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -26,6 +30,10 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; + private readonly IFeatureService _featureService; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ILogger _logger; public SubscriptionUpdatedHandler( IStripeEventService stripeEventService, @@ -39,7 +47,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler ISchedulerFactory schedulerFactory, IOrganizationEnableCommand organizationEnableCommand, IOrganizationDisableCommand organizationDisableCommand, - IPricingClient pricingClient) + IPricingClient pricingClient, + IFeatureService featureService, + IProviderRepository providerRepository, + IProviderService providerService, + ILogger logger) { _stripeEventService = stripeEventService; _stripeEventUtilityService = stripeEventUtilityService; @@ -53,6 +65,10 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler _organizationEnableCommand = organizationEnableCommand; _organizationDisableCommand = organizationDisableCommand; _pricingClient = pricingClient; + _featureService = featureService; + _providerRepository = providerRepository; + _providerService = providerService; + _logger = logger; } /// @@ -61,7 +77,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler /// public async Task HandleAsync(Event parsedEvent) { - var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice"]); + var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice", "test_clock"]); var (organizationId, userId, providerId) = _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); switch (subscription.Status) @@ -77,6 +93,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } break; } + case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired when providerId.HasValue: + { + await HandleUnpaidProviderSubscriptionAsync(providerId.Value, parsedEvent, subscription); + break; + } case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired: { if (!userId.HasValue) @@ -238,4 +259,71 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler await scheduler.ScheduleJob(job, trigger); } + + private async Task HandleUnpaidProviderSubscriptionAsync( + Guid providerId, + Event parsedEvent, + Subscription subscription) + { + var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); + + if (!providerPortalTakeover) + { + return; + } + + var provider = await _providerRepository.GetByIdAsync(providerId); + if (provider == null) + { + return; + } + + try + { + provider.Enabled = false; + await _providerService.UpdateAsync(provider); + + if (parsedEvent.Data.PreviousAttributes != null) + { + if (parsedEvent.Data.PreviousAttributes.ToObject() as Subscription is + { + Status: + StripeSubscriptionStatus.Trialing or + StripeSubscriptionStatus.Active or + StripeSubscriptionStatus.PastDue + } && subscription is + { + Status: StripeSubscriptionStatus.Unpaid, + LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" + }) + { + if (subscription.TestClock != null) + { + await WaitForTestClockToAdvanceAsync(subscription.TestClock); + } + + var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + await _stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }); + } + } + } + catch (Exception exception) + { + _logger.LogError(exception, "An error occurred while trying to disable and schedule subscription cancellation for provider ({ProviderID})", providerId); + } + } + + private async Task WaitForTestClockToAdvanceAsync(TestClock testClock) + { + while (testClock.Status != "ready") + { + await Task.Delay(TimeSpan.FromSeconds(2)); + testClock = await _stripeFacade.GetTestClock(testClock.Id); + if (testClock.Status == "internal_failure") + { + throw new Exception("Stripe Test Clock encountered an internal failure"); + } + } + } } diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index 24b5372ba1..cfbc90c36e 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Net.Http.Headers; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; +using Bit.Commercial.Core.Utilities; using Bit.Core.Billing.Extensions; using Bit.Core.Context; using Bit.Core.SecretsManager.Repositories; @@ -83,6 +84,7 @@ public class Startup services.AddDefaultServices(globalSettings); services.AddDistributedCache(globalSettings); services.AddBillingOperations(); + services.AddCommercialCoreServices(); services.TryAddSingleton(); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 9573a0ce0a..08191ff356 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -159,6 +159,7 @@ public static class FeatureFlagKeys public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge"; public const string PM21383_GetProviderPriceFromStripe = "pm-21383-get-provider-price-from-stripe"; public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout"; + public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 9c58bbdbf7..ce4ee608cc 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -1,8 +1,12 @@ using Bit.Billing.Constants; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; +using Bit.Core; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; @@ -10,10 +14,12 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; +using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; using NSubstitute; using Quartz; using Stripe; +using Stripe.TestHelpers; using Xunit; using Event = Stripe.Event; @@ -33,6 +39,10 @@ public class SubscriptionUpdatedHandlerTests private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; + private readonly IFeatureService _featureService; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ILogger _logger; private readonly IScheduler _scheduler; private readonly SubscriptionUpdatedHandler _sut; @@ -50,6 +60,10 @@ public class SubscriptionUpdatedHandlerTests _organizationEnableCommand = Substitute.For(); _organizationDisableCommand = Substitute.For(); _pricingClient = Substitute.For(); + _featureService = Substitute.For(); + _providerRepository = Substitute.For(); + _providerService = Substitute.For(); + _logger = Substitute.For>(); _scheduler = Substitute.For(); _schedulerFactory.GetScheduler().Returns(_scheduler); @@ -66,7 +80,11 @@ public class SubscriptionUpdatedHandlerTests _schedulerFactory, _organizationEnableCommand, _organizationDisableCommand, - _pricingClient); + _pricingClient, + _featureService, + _providerRepository, + _providerService, + _logger); } [Fact] @@ -104,6 +122,300 @@ public class SubscriptionUpdatedHandlerTests Arg.Is(t => t.Key.Name == $"cancel-trigger-{subscriptionId}")); } + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() + { + // Arrange + var providerId = Guid.NewGuid(); + const string subscriptionId = "sub_123"; + var frozenTime = DateTime.UtcNow; + + var testClock = new TestClock + { + Id = "clock_123", + Status = "ready", + FrozenTime = frozenTime + }; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }, + TestClock = testClock + }; + + var provider = new Provider + { + Id = providerId, + Name = "Test Provider", + Enabled = true + }; + + var parsedEvent = new Event + { + Data = new EventData + { + PreviousAttributes = JObject.FromObject(new + { + status = "active" + }) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(true); + + _providerRepository.GetByIdAsync(providerId) + .Returns(provider); + + _stripeFacade.GetTestClock(testClock.Id) + .Returns(testClock); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _stripeFacade.Received(1).UpdateSubscription(subscriptionId, + Arg.Is(o => o.CancelAt == frozenTime.AddDays(7))); + } + + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WithoutValidTransition_DisablesProviderOnly() + { + // Arrange + var providerId = Guid.NewGuid(); + const string subscriptionId = "sub_123"; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + }; + + var provider = new Provider + { + Id = providerId, + Name = "Test Provider", + Enabled = true + }; + + var parsedEvent = new Event + { + Data = new EventData + { + PreviousAttributes = JObject.FromObject(new + { + status = "unpaid" // No valid transition + }) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(true); + + _providerRepository.GetByIdAsync(providerId) + .Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WithNoPreviousAttributes_DisablesProviderOnly() + { + // Arrange + var providerId = Guid.NewGuid(); + const string subscriptionId = "sub_123"; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + }; + + var provider = new Provider + { + Id = providerId, + Name = "Test Provider", + Enabled = true + }; + + var parsedEvent = new Event + { + Data = new EventData + { + PreviousAttributes = null + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(true); + + _providerRepository.GetByIdAsync(providerId) + .Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WithIncompleteExpiredStatus_DisablesProvider() + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.IncompleteExpired, + CurrentPeriodEnd = currentPeriodEnd, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "renewal" } + }; + + var provider = new Provider + { + Id = providerId, + Name = "Test Provider", + Enabled = true + }; + + var parsedEvent = new Event { Data = new EventData() }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(true); + + _providerRepository.GetByIdAsync(providerId) + .Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WhenFeatureFlagDisabled_DoesNothing() + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + CurrentPeriodEnd = currentPeriodEnd, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); + await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_DoesNothing() + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + CurrentPeriodEnd = currentPeriodEnd, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, null, providerId)); + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) + .Returns(true); + + _providerRepository.GetByIdAsync(providerId) + .Returns((Provider)null); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); + } + [Fact] public async Task HandleAsync_UnpaidUserSubscription_DisablesPremiumAndCancelsSubscription() {