diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj
index 116efdb68c..25327b17b7 100644
--- a/src/Billing/Billing.csproj
+++ b/src/Billing/Billing.csproj
@@ -6,6 +6,7 @@
+
diff --git a/src/Billing/Services/IStripeFacade.cs b/src/Billing/Services/IStripeFacade.cs
index 6886250a33..37ba51cc61 100644
--- a/src/Billing/Services/IStripeFacade.cs
+++ b/src/Billing/Services/IStripeFacade.cs
@@ -2,6 +2,7 @@
#nullable disable
using Stripe;
+using Stripe.TestHelpers;
namespace Bit.Billing.Services;
@@ -98,4 +99,10 @@ public interface IStripeFacade
string subscriptionId,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
+
+ Task GetTestClock(
+ string testClockId,
+ TestClockGetOptions testClockGetOptions = null,
+ RequestOptions requestOptions = null,
+ CancellationToken cancellationToken = default);
}
diff --git a/src/Billing/Services/Implementations/StripeFacade.cs b/src/Billing/Services/Implementations/StripeFacade.cs
index 70144d8cd3..726a3e977c 100644
--- a/src/Billing/Services/Implementations/StripeFacade.cs
+++ b/src/Billing/Services/Implementations/StripeFacade.cs
@@ -2,6 +2,8 @@
#nullable disable
using Stripe;
+using Stripe.TestHelpers;
+using CustomerService = Stripe.CustomerService;
namespace Bit.Billing.Services.Implementations;
@@ -14,6 +16,7 @@ public class StripeFacade : IStripeFacade
private readonly PaymentMethodService _paymentMethodService = new();
private readonly SubscriptionService _subscriptionService = new();
private readonly DiscountService _discountService = new();
+ private readonly TestClockService _testClockService = new();
public async Task GetCharge(
string chargeId,
@@ -119,4 +122,11 @@ public class StripeFacade : IStripeFacade
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default) =>
await _discountService.DeleteSubscriptionDiscountAsync(subscriptionId, requestOptions, cancellationToken);
+
+ public Task GetTestClock(
+ string testClockId,
+ TestClockGetOptions testClockGetOptions = null,
+ RequestOptions requestOptions = null,
+ CancellationToken cancellationToken = default) =>
+ _testClockService.GetAsync(testClockId, testClockGetOptions, requestOptions, cancellationToken);
}
diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs
index fe5021c827..bbc17aa3b2 100644
--- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs
+++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs
@@ -1,6 +1,9 @@
using Bit.Billing.Constants;
using Bit.Billing.Jobs;
+using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Pricing;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Platform.Push;
@@ -8,6 +11,7 @@ using Bit.Core.Repositories;
using Bit.Core.Services;
using Quartz;
using Stripe;
+using Stripe.TestHelpers;
using Event = Stripe.Event;
namespace Bit.Billing.Services.Implementations;
@@ -26,6 +30,10 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient;
+ private readonly IFeatureService _featureService;
+ private readonly IProviderRepository _providerRepository;
+ private readonly IProviderService _providerService;
+ private readonly ILogger _logger;
public SubscriptionUpdatedHandler(
IStripeEventService stripeEventService,
@@ -39,7 +47,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
ISchedulerFactory schedulerFactory,
IOrganizationEnableCommand organizationEnableCommand,
IOrganizationDisableCommand organizationDisableCommand,
- IPricingClient pricingClient)
+ IPricingClient pricingClient,
+ IFeatureService featureService,
+ IProviderRepository providerRepository,
+ IProviderService providerService,
+ ILogger logger)
{
_stripeEventService = stripeEventService;
_stripeEventUtilityService = stripeEventUtilityService;
@@ -53,6 +65,10 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
_organizationEnableCommand = organizationEnableCommand;
_organizationDisableCommand = organizationDisableCommand;
_pricingClient = pricingClient;
+ _featureService = featureService;
+ _providerRepository = providerRepository;
+ _providerService = providerService;
+ _logger = logger;
}
///
@@ -61,7 +77,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
///
public async Task HandleAsync(Event parsedEvent)
{
- var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice"]);
+ var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice", "test_clock"]);
var (organizationId, userId, providerId) = _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata);
switch (subscription.Status)
@@ -77,6 +93,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
}
break;
}
+ case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired when providerId.HasValue:
+ {
+ await HandleUnpaidProviderSubscriptionAsync(providerId.Value, parsedEvent, subscription);
+ break;
+ }
case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired:
{
if (!userId.HasValue)
@@ -238,4 +259,71 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
await scheduler.ScheduleJob(job, trigger);
}
+
+ private async Task HandleUnpaidProviderSubscriptionAsync(
+ Guid providerId,
+ Event parsedEvent,
+ Subscription subscription)
+ {
+ var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover);
+
+ if (!providerPortalTakeover)
+ {
+ return;
+ }
+
+ var provider = await _providerRepository.GetByIdAsync(providerId);
+ if (provider == null)
+ {
+ return;
+ }
+
+ try
+ {
+ provider.Enabled = false;
+ await _providerService.UpdateAsync(provider);
+
+ if (parsedEvent.Data.PreviousAttributes != null)
+ {
+ if (parsedEvent.Data.PreviousAttributes.ToObject() as Subscription is
+ {
+ Status:
+ StripeSubscriptionStatus.Trialing or
+ StripeSubscriptionStatus.Active or
+ StripeSubscriptionStatus.PastDue
+ } && subscription is
+ {
+ Status: StripeSubscriptionStatus.Unpaid,
+ LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create"
+ })
+ {
+ if (subscription.TestClock != null)
+ {
+ await WaitForTestClockToAdvanceAsync(subscription.TestClock);
+ }
+
+ var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;
+ await _stripeFacade.UpdateSubscription(subscription.Id,
+ new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) });
+ }
+ }
+ }
+ catch (Exception exception)
+ {
+ _logger.LogError(exception, "An error occurred while trying to disable and schedule subscription cancellation for provider ({ProviderID})", providerId);
+ }
+ }
+
+ private async Task WaitForTestClockToAdvanceAsync(TestClock testClock)
+ {
+ while (testClock.Status != "ready")
+ {
+ await Task.Delay(TimeSpan.FromSeconds(2));
+ testClock = await _stripeFacade.GetTestClock(testClock.Id);
+ if (testClock.Status == "internal_failure")
+ {
+ throw new Exception("Stripe Test Clock encountered an internal failure");
+ }
+ }
+ }
}
diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs
index 24b5372ba1..cfbc90c36e 100644
--- a/src/Billing/Startup.cs
+++ b/src/Billing/Startup.cs
@@ -5,6 +5,7 @@ using System.Globalization;
using System.Net.Http.Headers;
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
+using Bit.Commercial.Core.Utilities;
using Bit.Core.Billing.Extensions;
using Bit.Core.Context;
using Bit.Core.SecretsManager.Repositories;
@@ -83,6 +84,7 @@ public class Startup
services.AddDefaultServices(globalSettings);
services.AddDistributedCache(globalSettings);
services.AddBillingOperations();
+ services.AddCommercialCoreServices();
services.TryAddSingleton();
diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs
index 9573a0ce0a..08191ff356 100644
--- a/src/Core/Constants.cs
+++ b/src/Core/Constants.cs
@@ -159,6 +159,7 @@ public static class FeatureFlagKeys
public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge";
public const string PM21383_GetProviderPriceFromStripe = "pm-21383-get-provider-price-from-stripe";
public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout";
+ public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover";
/* Key Management Team */
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs
index 9c58bbdbf7..ce4ee608cc 100644
--- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs
+++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs
@@ -1,8 +1,12 @@
using Bit.Billing.Constants;
using Bit.Billing.Services;
using Bit.Billing.Services.Implementations;
+using Bit.Core;
using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models.StaticStore.Plans;
using Bit.Core.Billing.Pricing;
@@ -10,10 +14,12 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
+using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using NSubstitute;
using Quartz;
using Stripe;
+using Stripe.TestHelpers;
using Xunit;
using Event = Stripe.Event;
@@ -33,6 +39,10 @@ public class SubscriptionUpdatedHandlerTests
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient;
+ private readonly IFeatureService _featureService;
+ private readonly IProviderRepository _providerRepository;
+ private readonly IProviderService _providerService;
+ private readonly ILogger _logger;
private readonly IScheduler _scheduler;
private readonly SubscriptionUpdatedHandler _sut;
@@ -50,6 +60,10 @@ public class SubscriptionUpdatedHandlerTests
_organizationEnableCommand = Substitute.For();
_organizationDisableCommand = Substitute.For();
_pricingClient = Substitute.For();
+ _featureService = Substitute.For();
+ _providerRepository = Substitute.For();
+ _providerService = Substitute.For();
+ _logger = Substitute.For>();
_scheduler = Substitute.For();
_schedulerFactory.GetScheduler().Returns(_scheduler);
@@ -66,7 +80,11 @@ public class SubscriptionUpdatedHandlerTests
_schedulerFactory,
_organizationEnableCommand,
_organizationDisableCommand,
- _pricingClient);
+ _pricingClient,
+ _featureService,
+ _providerRepository,
+ _providerService,
+ _logger);
}
[Fact]
@@ -104,6 +122,300 @@ public class SubscriptionUpdatedHandlerTests
Arg.Is(t => t.Key.Name == $"cancel-trigger-{subscriptionId}"));
}
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ const string subscriptionId = "sub_123";
+ var frozenTime = DateTime.UtcNow;
+
+ var testClock = new TestClock
+ {
+ Id = "clock_123",
+ Status = "ready",
+ FrozenTime = frozenTime
+ };
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.Unpaid,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "subscription_cycle" },
+ TestClock = testClock
+ };
+
+ var provider = new Provider
+ {
+ Id = providerId,
+ Name = "Test Provider",
+ Enabled = true
+ };
+
+ var parsedEvent = new Event
+ {
+ Data = new EventData
+ {
+ PreviousAttributes = JObject.FromObject(new
+ {
+ status = "active"
+ })
+ }
+ };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(true);
+
+ _providerRepository.GetByIdAsync(providerId)
+ .Returns(provider);
+
+ _stripeFacade.GetTestClock(testClock.Id)
+ .Returns(testClock);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ Assert.False(provider.Enabled);
+ await _providerService.Received(1).UpdateAsync(provider);
+ await _stripeFacade.Received(1).UpdateSubscription(subscriptionId,
+ Arg.Is(o => o.CancelAt == frozenTime.AddDays(7)));
+ }
+
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WithoutValidTransition_DisablesProviderOnly()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ const string subscriptionId = "sub_123";
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.Unpaid,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }
+ };
+
+ var provider = new Provider
+ {
+ Id = providerId,
+ Name = "Test Provider",
+ Enabled = true
+ };
+
+ var parsedEvent = new Event
+ {
+ Data = new EventData
+ {
+ PreviousAttributes = JObject.FromObject(new
+ {
+ status = "unpaid" // No valid transition
+ })
+ }
+ };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(true);
+
+ _providerRepository.GetByIdAsync(providerId)
+ .Returns(provider);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ Assert.False(provider.Enabled);
+ await _providerService.Received(1).UpdateAsync(provider);
+ await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any());
+ }
+
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WithNoPreviousAttributes_DisablesProviderOnly()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ const string subscriptionId = "sub_123";
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.Unpaid,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }
+ };
+
+ var provider = new Provider
+ {
+ Id = providerId,
+ Name = "Test Provider",
+ Enabled = true
+ };
+
+ var parsedEvent = new Event
+ {
+ Data = new EventData
+ {
+ PreviousAttributes = null
+ }
+ };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(true);
+
+ _providerRepository.GetByIdAsync(providerId)
+ .Returns(provider);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ Assert.False(provider.Enabled);
+ await _providerService.Received(1).UpdateAsync(provider);
+ await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any());
+ }
+
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WithIncompleteExpiredStatus_DisablesProvider()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ var subscriptionId = "sub_123";
+ var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.IncompleteExpired,
+ CurrentPeriodEnd = currentPeriodEnd,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "renewal" }
+ };
+
+ var provider = new Provider
+ {
+ Id = providerId,
+ Name = "Test Provider",
+ Enabled = true
+ };
+
+ var parsedEvent = new Event { Data = new EventData() };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(true);
+
+ _providerRepository.GetByIdAsync(providerId)
+ .Returns(provider);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ Assert.False(provider.Enabled);
+ await _providerService.Received(1).UpdateAsync(provider);
+ await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any());
+ }
+
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WhenFeatureFlagDisabled_DoesNothing()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ var subscriptionId = "sub_123";
+ var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.Unpaid,
+ CurrentPeriodEnd = currentPeriodEnd,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }
+ };
+
+ var parsedEvent = new Event { Data = new EventData() };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(false);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any());
+ await _providerService.DidNotReceive().UpdateAsync(Arg.Any());
+ }
+
+ [Fact]
+ public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_DoesNothing()
+ {
+ // Arrange
+ var providerId = Guid.NewGuid();
+ var subscriptionId = "sub_123";
+ var currentPeriodEnd = DateTime.UtcNow.AddDays(30);
+
+ var subscription = new Subscription
+ {
+ Id = subscriptionId,
+ Status = StripeSubscriptionStatus.Unpaid,
+ CurrentPeriodEnd = currentPeriodEnd,
+ Metadata = new Dictionary { { "providerId", providerId.ToString() } },
+ LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }
+ };
+
+ var parsedEvent = new Event { Data = new EventData() };
+
+ _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>())
+ .Returns(subscription);
+
+ _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>())
+ .Returns(Tuple.Create(null, null, providerId));
+
+ _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover)
+ .Returns(true);
+
+ _providerRepository.GetByIdAsync(providerId)
+ .Returns((Provider)null);
+
+ // Act
+ await _sut.HandleAsync(parsedEvent);
+
+ // Assert
+ await _providerService.DidNotReceive().UpdateAsync(Arg.Any());
+ await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any());
+ }
+
[Fact]
public async Task HandleAsync_UnpaidUserSubscription_DisablesPremiumAndCancelsSubscription()
{