diff --git a/src/Billing/Jobs/ProviderOrganizationDisableJob.cs b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs new file mode 100644 index 0000000000..5a48dd609f --- /dev/null +++ b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs @@ -0,0 +1,88 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Quartz; + +namespace Bit.Billing.Jobs; + +public class ProviderOrganizationDisableJob( + IProviderOrganizationRepository providerOrganizationRepository, + IOrganizationDisableCommand organizationDisableCommand, + ILogger logger) + : IJob +{ + private const int MaxConcurrency = 5; + private const int MaxTimeoutMinutes = 10; + + public async Task Execute(IJobExecutionContext context) + { + var providerId = new Guid(context.MergedJobDataMap.GetString("providerId") ?? string.Empty); + var expirationDateString = context.MergedJobDataMap.GetString("expirationDate"); + DateTime? expirationDate = string.IsNullOrEmpty(expirationDateString) + ? null + : DateTime.Parse(expirationDateString); + + logger.LogInformation("Starting to disable organizations for provider {ProviderId}", providerId); + + var startTime = DateTime.UtcNow; + var totalProcessed = 0; + var totalErrors = 0; + + try + { + var providerOrganizations = await providerOrganizationRepository + .GetManyDetailsByProviderAsync(providerId); + + if (providerOrganizations == null || !providerOrganizations.Any()) + { + logger.LogInformation("No organizations found for provider {ProviderId}", providerId); + return; + } + + logger.LogInformation("Disabling {OrganizationCount} organizations for provider {ProviderId}", + providerOrganizations.Count, providerId); + + var semaphore = new SemaphoreSlim(MaxConcurrency, MaxConcurrency); + var tasks = providerOrganizations.Select(async po => + { + if (DateTime.UtcNow.Subtract(startTime).TotalMinutes > MaxTimeoutMinutes) + { + logger.LogWarning("Timeout reached while disabling organizations for provider {ProviderId}", providerId); + return false; + } + + await semaphore.WaitAsync(); + try + { + await organizationDisableCommand.DisableAsync(po.OrganizationId, expirationDate); + Interlocked.Increment(ref totalProcessed); + return true; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to disable organization {OrganizationId} for provider {ProviderId}", + po.OrganizationId, providerId); + Interlocked.Increment(ref totalErrors); + return false; + } + finally + { + semaphore.Release(); + } + }); + + await Task.WhenAll(tasks); + + logger.LogInformation("Completed disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + } + catch (Exception ex) + { + logger.LogError(ex, "Error disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + throw; + } + } +} diff --git a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs index 13adf9825d..c204cc5026 100644 --- a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs @@ -1,7 +1,11 @@ using Bit.Billing.Constants; +using Bit.Billing.Jobs; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Extensions; using Bit.Core.Services; +using Quartz; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -11,17 +15,26 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler private readonly IUserService _userService; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ISchedulerFactory _schedulerFactory; public SubscriptionDeletedHandler( IStripeEventService stripeEventService, IUserService userService, IStripeEventUtilityService stripeEventUtilityService, - IOrganizationDisableCommand organizationDisableCommand) + IOrganizationDisableCommand organizationDisableCommand, + IProviderRepository providerRepository, + IProviderService providerService, + ISchedulerFactory schedulerFactory) { _stripeEventService = stripeEventService; _userService = userService; _stripeEventUtilityService = stripeEventUtilityService; _organizationDisableCommand = organizationDisableCommand; + _providerRepository = providerRepository; + _providerService = providerService; + _schedulerFactory = schedulerFactory; } /// @@ -53,9 +66,38 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.GetCurrentPeriodEnd()); } + else if (providerId.HasValue) + { + var provider = await _providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = false; + await _providerService.UpdateAsync(provider); + + await QueueProviderOrganizationDisableJobAsync(providerId.Value, subscription.GetCurrentPeriodEnd()); + } + } else if (userId.HasValue) { await _userService.DisablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); } } + + private async Task QueueProviderOrganizationDisableJobAsync(Guid providerId, DateTime? expirationDate) + { + var scheduler = await _schedulerFactory.GetScheduler(); + + var job = JobBuilder.Create() + .WithIdentity($"disable-provider-orgs-{providerId}", "provider-management") + .UsingJobData("providerId", providerId.ToString()) + .UsingJobData("expirationDate", expirationDate?.ToString("O")) + .Build(); + + var trigger = TriggerBuilder.Create() + .WithIdentity($"disable-trigger-{providerId}", "provider-management") + .StartNow() + .Build(); + + await scheduler.ScheduleJob(job, trigger); + } } diff --git a/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs b/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs new file mode 100644 index 0000000000..91b38341e5 --- /dev/null +++ b/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs @@ -0,0 +1,234 @@ +using Bit.Billing.Jobs; +using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Quartz; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class ProviderOrganizationDisableJobTests +{ + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly ILogger _logger; + private readonly ProviderOrganizationDisableJob _sut; + + public ProviderOrganizationDisableJobTests() + { + _providerOrganizationRepository = Substitute.For(); + _organizationDisableCommand = Substitute.For(); + _logger = Substitute.For>(); + _sut = new ProviderOrganizationDisableJob( + _providerOrganizationRepository, + _organizationDisableCommand, + _logger); + } + + [Fact] + public async Task Execute_NoOrganizations_LogsAndReturns() + { + // Arrange + var providerId = Guid.NewGuid(); + var context = CreateJobExecutionContext(providerId, DateTime.UtcNow); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns((ICollection)null); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); + } + + [Fact] + public async Task Execute_WithOrganizations_DisablesAllOrganizations() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var org1Id = Guid.NewGuid(); + var org2Id = Guid.NewGuid(); + var org3Id = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = org1Id }, + new() { OrganizationId = org2Id }, + new() { OrganizationId = org3Id } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any()); + } + + [Fact] + public async Task Execute_WithExpirationDate_PassesDateToDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = new DateTime(2025, 12, 31, 23, 59, 59); + var orgId = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = orgId } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(orgId, expirationDate); + } + + [Fact] + public async Task Execute_WithNullExpirationDate_PassesNullToDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var orgId = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = orgId } + }; + + var context = CreateJobExecutionContext(providerId, null); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(orgId, null); + } + + [Fact] + public async Task Execute_OneOrganizationFails_ContinuesProcessingOthers() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var org1Id = Guid.NewGuid(); + var org2Id = Guid.NewGuid(); + var org3Id = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = org1Id }, + new() { OrganizationId = org2Id }, + new() { OrganizationId = org3Id } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Make org2 fail + _organizationDisableCommand.DisableAsync(org2Id, Arg.Any()) + .Throws(new Exception("Database error")); + + // Act + await _sut.Execute(context); + + // Assert - all three should be attempted + await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any()); + } + + [Fact] + public async Task Execute_ManyOrganizations_ProcessesWithLimitedConcurrency() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + + // Create 20 organizations + var organizations = Enumerable.Range(1, 20) + .Select(_ => new ProviderOrganizationOrganizationDetails { OrganizationId = Guid.NewGuid() }) + .ToList(); + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + var concurrentCalls = 0; + var maxConcurrentCalls = 0; + var lockObj = new object(); + + _organizationDisableCommand.DisableAsync(Arg.Any(), Arg.Any()) + .Returns(callInfo => + { + lock (lockObj) + { + concurrentCalls++; + if (concurrentCalls > maxConcurrentCalls) + { + maxConcurrentCalls = concurrentCalls; + } + } + + return Task.Delay(50).ContinueWith(_ => + { + lock (lockObj) + { + concurrentCalls--; + } + }); + }); + + // Act + await _sut.Execute(context); + + // Assert + Assert.True(maxConcurrentCalls <= 5, $"Expected max concurrency of 5, but got {maxConcurrentCalls}"); + await _organizationDisableCommand.Received(20).DisableAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_EmptyOrganizationsList_DoesNotCallDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var context = CreateJobExecutionContext(providerId, DateTime.UtcNow); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(new List()); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); + } + + private static IJobExecutionContext CreateJobExecutionContext(Guid providerId, DateTime? expirationDate) + { + var context = Substitute.For(); + var jobDataMap = new JobDataMap + { + { "providerId", providerId.ToString() }, + { "expirationDate", expirationDate?.ToString("O") } + }; + context.MergedJobDataMap.Returns(jobDataMap); + return context; + } +} diff --git a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs index 78dc5aa791..de2d3ec0ed 100644 --- a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs @@ -1,10 +1,15 @@ using Bit.Billing.Constants; +using Bit.Billing.Jobs; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Extensions; using Bit.Core.Services; using NSubstitute; +using Quartz; using Stripe; using Xunit; @@ -16,6 +21,10 @@ public class SubscriptionDeletedHandlerTests private readonly IUserService _userService; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ISchedulerFactory _schedulerFactory; + private readonly IScheduler _scheduler; private readonly SubscriptionDeletedHandler _sut; public SubscriptionDeletedHandlerTests() @@ -24,11 +33,19 @@ public class SubscriptionDeletedHandlerTests _userService = Substitute.For(); _stripeEventUtilityService = Substitute.For(); _organizationDisableCommand = Substitute.For(); + _providerRepository = Substitute.For(); + _providerService = Substitute.For(); + _schedulerFactory = Substitute.For(); + _scheduler = Substitute.For(); + _schedulerFactory.GetScheduler().Returns(_scheduler); _sut = new SubscriptionDeletedHandler( _stripeEventService, _userService, _stripeEventUtilityService, - _organizationDisableCommand); + _organizationDisableCommand, + _providerRepository, + _providerService, + _schedulerFactory); } [Fact] @@ -59,6 +76,7 @@ public class SubscriptionDeletedHandlerTests // Assert await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); await _userService.DidNotReceiveWithAnyArgs().DisablePremiumAsync(default, default); + await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default); } [Fact] @@ -192,4 +210,120 @@ public class SubscriptionDeletedHandlerTests await _organizationDisableCommand.DidNotReceiveWithAnyArgs() .DisableAsync(default, default); } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_DisablesProviderAndQueuesJob() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var provider = new Provider + { + Id = providerId, + Enabled = true + }; + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await _sut.HandleAsync(stripeEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _scheduler.Received(1).ScheduleJob( + Arg.Is(j => j.JobType == typeof(ProviderOrganizationDisableJob)), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_ProviderNotFound_DoesNotThrow() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns((Provider)null); + + // Act & Assert - Should not throw + await _sut.HandleAsync(stripeEvent); + + // Assert + await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default); + await _scheduler.DidNotReceiveWithAnyArgs().ScheduleJob(default, default); + } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_QueuesJobWithCorrectParameters() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var provider = new Provider + { + Id = providerId, + Enabled = true + }; + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = expirationDate } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await _sut.HandleAsync(stripeEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _scheduler.Received(1).ScheduleJob( + Arg.Is(j => + j.JobType == typeof(ProviderOrganizationDisableJob) && + j.JobDataMap.GetString("providerId") == providerId.ToString() && + j.JobDataMap.GetString("expirationDate") == expirationDate.ToString("O")), + Arg.Is(t => t.Key.Name == $"disable-trigger-{providerId}")); + } }