diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 702d9aaf3d..d5fcfb20d4 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,4 +1,5 @@ -using Bit.Billing.Constants; +using System.Globalization; +using Bit.Billing.Constants; using Bit.Billing.Jobs; using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; @@ -316,7 +317,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler private async Task HandleUnpaidProviderSubscriptionAsync( Guid providerId, Event parsedEvent, - Subscription subscription) + Subscription currentSubscription) { var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); @@ -338,26 +339,43 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler if (parsedEvent.Data.PreviousAttributes != null) { - if (parsedEvent.Data.PreviousAttributes.ToObject() as Subscription is - { - Status: - StripeSubscriptionStatus.Trialing or - StripeSubscriptionStatus.Active or - StripeSubscriptionStatus.PastDue - } && subscription is - { - Status: StripeSubscriptionStatus.Unpaid, - LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" - }) + var previousSubscription = parsedEvent.Data.PreviousAttributes.ToObject() as Subscription; + + var updateIsSubscriptionGoingUnpaid = previousSubscription is { - if (subscription.TestClock != null) + Status: + StripeSubscriptionStatus.Trialing or + StripeSubscriptionStatus.Active or + StripeSubscriptionStatus.PastDue + } && currentSubscription is + { + Status: StripeSubscriptionStatus.Unpaid, + LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" + }; + + var updateIsManualSuspensionViaMetadata = CheckForManualSuspensionViaMetadata( + previousSubscription, currentSubscription); + + if (updateIsSubscriptionGoingUnpaid || updateIsManualSuspensionViaMetadata) + { + if (currentSubscription.TestClock != null) { - await WaitForTestClockToAdvanceAsync(subscription.TestClock); + await WaitForTestClockToAdvanceAsync(currentSubscription.TestClock); } - var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; - await _stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }); + var now = currentSubscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + + var subscriptionUpdateOptions = new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }; + + if (updateIsManualSuspensionViaMetadata) + { + subscriptionUpdateOptions.Metadata = new Dictionary + { + ["suspended_provider_via_webhook_at"] = DateTime.UtcNow.ToString(CultureInfo.InvariantCulture) + }; + } + + await _stripeFacade.UpdateSubscription(currentSubscription.Id, subscriptionUpdateOptions); } } } @@ -379,4 +397,37 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } } + + private static bool CheckForManualSuspensionViaMetadata( + Subscription? previousSubscription, + Subscription currentSubscription) + { + /* + * When metadata on a subscription is updated, we'll receive an event that has: + * Previous Metadata: { newlyAddedKey: null } + * Current Metadata: { newlyAddedKey: newlyAddedValue } + * + * As such, our check for a manual suspension must ensure that the 'previous_attributes' does contain the + * 'metadata' property, but also that the "suspend_provider" key in that metadata is set to null. + * + * If we don't do this and instead do a null coalescing check on 'previous_attributes?.metadata?.TryGetValue', + * we'll end up marking an event where 'previous_attributes.metadata' = null (which could be any subscription update + * that does not update the metadata) the same as a manual suspension. + */ + const string key = "suspend_provider"; + + if (previousSubscription is not { Metadata: not null } || + !previousSubscription.Metadata.TryGetValue(key, out var previousValue)) + { + return false; + } + + if (previousValue == null) + { + return !string.IsNullOrEmpty( + currentSubscription.Metadata.TryGetValue(key, out var currentValue) ? currentValue : null); + } + + return false; + } } diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index f230b87dea..0d1f54ecfd 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -20,7 +20,6 @@ using NSubstitute; using NSubstitute.ReturnsExtensions; using Quartz; using Stripe; -using Stripe.TestHelpers; using Xunit; using Event = Stripe.Event; @@ -36,14 +35,12 @@ public class SubscriptionUpdatedHandlerTests private readonly IUserService _userService; private readonly IPushNotificationService _pushNotificationService; private readonly IOrganizationRepository _organizationRepository; - private readonly ISchedulerFactory _schedulerFactory; private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; private readonly IFeatureService _featureService; private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; - private readonly ILogger _logger; private readonly IScheduler _scheduler; private readonly SubscriptionUpdatedHandler _sut; @@ -58,18 +55,17 @@ public class SubscriptionUpdatedHandlerTests _providerService = Substitute.For(); _pushNotificationService = Substitute.For(); _organizationRepository = Substitute.For(); - _providerRepository = Substitute.For(); - _schedulerFactory = Substitute.For(); + var schedulerFactory = Substitute.For(); _organizationEnableCommand = Substitute.For(); _organizationDisableCommand = Substitute.For(); _pricingClient = Substitute.For(); _featureService = Substitute.For(); _providerRepository = Substitute.For(); _providerService = Substitute.For(); - _logger = Substitute.For>(); + var logger = Substitute.For>(); _scheduler = Substitute.For(); - _schedulerFactory.GetScheduler().Returns(_scheduler); + schedulerFactory.GetScheduler().Returns(_scheduler); _sut = new SubscriptionUpdatedHandler( _stripeEventService, @@ -80,14 +76,14 @@ public class SubscriptionUpdatedHandlerTests _userService, _pushNotificationService, _organizationRepository, - _schedulerFactory, + schedulerFactory, _organizationEnableCommand, _organizationDisableCommand, _pricingClient, _featureService, _providerRepository, _providerService, - _logger); + logger); } [Fact] @@ -126,61 +122,54 @@ public class SubscriptionUpdatedHandlerTests } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() + public async Task + HandleAsync_UnpaidProviderSubscription_WithManualSuspensionViaMetadata_DisablesProviderAndSchedulesCancellation() { // Arrange var providerId = Guid.NewGuid(); - const string subscriptionId = "sub_123"; - var frozenTime = DateTime.UtcNow; + var subscriptionId = "sub_test123"; - var testClock = new TestClock + var previousSubscription = new Subscription { - Id = "clock_123", - Status = "ready", - FrozenTime = frozenTime + Id = subscriptionId, + Status = StripeSubscriptionStatus.Active, + Metadata = new Dictionary + { + ["suspend_provider"] = null // This is the key part - metadata exists, but value is null + } }; - var subscription = new Subscription + var currentSubscription = new Subscription { Id = subscriptionId, Status = StripeSubscriptionStatus.Unpaid, - Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }, - TestClock = testClock - }; - - var provider = new Provider - { - Id = providerId, - Name = "Test Provider", - Enabled = true + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), + Metadata = new Dictionary + { + ["providerId"] = providerId.ToString(), + ["suspend_provider"] = "true" // Now has a value, indicating manual suspension + }, + TestClock = null }; var parsedEvent = new Event { + Id = "evt_test123", + Type = HandledStripeWebhook.SubscriptionUpdated, Data = new EventData { - PreviousAttributes = JObject.FromObject(new - { - status = "active" - }) + Object = currentSubscription, + PreviousAttributes = JObject.FromObject(previousSubscription) } }; - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); + var provider = new Provider { Id = providerId, Enabled = true }; - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); + _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); + _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) .Returns(Tuple.Create(null, null, providerId)); - - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - - _providerRepository.GetByIdAsync(providerId) - .Returns(provider); - - _stripeFacade.GetTestClock(testClock.Id) - .Returns(testClock); + _providerRepository.GetByIdAsync(providerId).Returns(provider); // Act await _sut.HandleAsync(parsedEvent); @@ -188,8 +177,75 @@ public class SubscriptionUpdatedHandlerTests // Assert Assert.False(provider.Enabled); await _providerService.Received(1).UpdateAsync(provider); - await _stripeFacade.Received(1).UpdateSubscription(subscriptionId, - Arg.Is(o => o.CancelAt == frozenTime.AddDays(7))); + + // Verify that UpdateSubscription was called with both CancelAt and the new metadata + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAt.HasValue && + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + options.Metadata != null && + options.Metadata.ContainsKey("suspended_provider_via_webhook_at"))); + } + + [Fact] + public async Task + HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() + { + // Arrange + var providerId = Guid.NewGuid(); + var subscriptionId = "sub_test123"; + + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Active, + Metadata = new Dictionary { ["providerId"] = providerId.ToString() } + }; + + var currentSubscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), + Metadata = new Dictionary { ["providerId"] = providerId.ToString() }, + LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }, + TestClock = null + }; + + var parsedEvent = new Event + { + Id = "evt_test123", + Type = HandledStripeWebhook.SubscriptionUpdated, + Data = new EventData + { + Object = currentSubscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; + + var provider = new Provider { Id = providerId, Enabled = true }; + + _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); + _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); + _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + + // Verify that UpdateSubscription was called with CancelAt but WITHOUT suspension metadata + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAt.HasValue && + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + (options.Metadata == null || !options.Metadata.ContainsKey("suspended_provider_via_webhook_at")))); } [Fact] @@ -207,12 +263,7 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } }; - var provider = new Provider - { - Id = providerId, - Name = "Test Provider", - Enabled = true - }; + var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; var parsedEvent = new Event { @@ -220,7 +271,7 @@ public class SubscriptionUpdatedHandlerTests { PreviousAttributes = JObject.FromObject(new { - status = "unpaid" // No valid transition + status = "unpaid" // No valid transition }) } }; @@ -261,20 +312,9 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } }; - var provider = new Provider - { - Id = providerId, - Name = "Test Provider", - Enabled = true - }; + var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event - { - Data = new EventData - { - PreviousAttributes = null - } - }; + var parsedEvent = new Event { Data = new EventData { PreviousAttributes = null } }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); @@ -314,12 +354,7 @@ public class SubscriptionUpdatedHandlerTests LatestInvoice = new Invoice { BillingReason = "renewal" } }; - var provider = new Provider - { - Id = providerId, - Name = "Test Provider", - Enabled = true - }; + var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; var parsedEvent = new Event { Data = new EventData() }; @@ -434,10 +469,10 @@ public class SubscriptionUpdatedHandlerTests Metadata = new Dictionary { { "userId", userId.ToString() } }, Items = new StripeList { - Data = new List - { - new() { Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } } - } + Data = + [ + new SubscriptionItem { Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } } + ] } }; @@ -478,11 +513,7 @@ public class SubscriptionUpdatedHandlerTests Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; - var organization = new Organization - { - Id = organizationId, - PlanType = PlanType.EnterpriseAnnually2023 - }; + var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 }; var parsedEvent = new Event { Data = new EventData() }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) @@ -495,7 +526,7 @@ public class SubscriptionUpdatedHandlerTests .Returns(organization); _stripeFacade.ListInvoices(Arg.Any()) - .Returns(new StripeList { Data = new List { new Invoice { Id = "inv_123" } } }); + .Returns(new StripeList { Data = [new Invoice { Id = "inv_123" }] }); var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) @@ -577,7 +608,8 @@ public class SubscriptionUpdatedHandlerTests } [Fact] - public async Task HandleAsync_WhenSubscriptionIsActive_AndOrganizationHasSecretsManagerTrial_AndRemovingSecretsManagerTrial_RemovesPasswordManagerCoupon() + public async Task + HandleAsync_WhenSubscriptionIsActive_AndOrganizationHasSecretsManagerTrial_AndRemovingSecretsManagerTrial_RemovesPasswordManagerCoupon() { // Arrange var organizationId = Guid.NewGuid(); @@ -589,34 +621,18 @@ public class SubscriptionUpdatedHandlerTests CustomerId = "cus_123", Items = new StripeList { - Data = new List - { - new() { Plan = new Stripe.Plan { Id = "2023-enterprise-org-seat-annually" } } - } + Data = [new SubscriptionItem { Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } }] }, Customer = new Customer { Balance = 0, - Discount = new Discount - { - Coupon = new Coupon { Id = "sm-standalone" } - } + Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } } }, - Discount = new Discount - { - Coupon = new Coupon { Id = "sm-standalone" } - }, - Metadata = new Dictionary - { - { "organizationId", organizationId.ToString() } - } + Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } }, + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; - var organization = new Organization - { - Id = organizationId, - PlanType = PlanType.EnterpriseAnnually2023 - }; + var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 }; var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) @@ -631,20 +647,14 @@ public class SubscriptionUpdatedHandlerTests { items = new { - data = new[] - { - new { plan = new { id = "secrets-manager-enterprise-seat-annually" } } - } + data = new[] { new { plan = new { id = "secrets-manager-enterprise-seat-annually" } } } }, Items = new StripeList { - Data = new List - { - new SubscriptionItem - { - Plan = new Stripe.Plan { Id = "secrets-manager-enterprise-seat-annually" } - } - } + Data = + [ + new SubscriptionItem { Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } } + ] } }) } @@ -990,7 +1000,7 @@ public class SubscriptionUpdatedHandlerTests { Id = previousSubscription?.Id ?? "sub_123", Status = StripeSubscriptionStatus.Active, - Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } }; var provider = new Provider { Id = providerId, Enabled = false }; @@ -1010,10 +1020,10 @@ public class SubscriptionUpdatedHandlerTests { return new List { - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Unpaid }, }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Incomplete }, }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.IncompleteExpired }, }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Paused }, }, + new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Unpaid } }, + new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Incomplete } }, + new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.IncompleteExpired } }, + new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Paused } } }; } }