diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index df333d5d4e..c0c138d0bc 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -22,6 +22,7 @@ using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -53,6 +54,7 @@ public class ProvidersController : Controller private readonly IPricingClient _pricingClient; private readonly IStripeAdapter _stripeAdapter; private readonly IAccessControlService _accessControlService; + private readonly ISubscriberService _subscriberService; private readonly string _stripeUrl; private readonly string _braintreeMerchantUrl; private readonly string _braintreeMerchantId; @@ -73,7 +75,8 @@ public class ProvidersController : Controller IWebHostEnvironment webHostEnvironment, IPricingClient pricingClient, IStripeAdapter stripeAdapter, - IAccessControlService accessControlService) + IAccessControlService accessControlService, + ISubscriberService subscriberService) { _organizationRepository = organizationRepository; _resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand; @@ -93,6 +96,7 @@ public class ProvidersController : Controller _braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl(); _braintreeMerchantId = globalSettings.Braintree.MerchantId; _accessControlService = accessControlService; + _subscriberService = subscriberService; } [RequirePermission(Permission.Provider_List_View)] @@ -299,6 +303,23 @@ public class ProvidersController : Controller model.ToProvider(provider); + // validate the stripe ids to prevent saving a bad one + if (provider.IsBillable()) + { + if (!await _subscriberService.IsValidGatewayCustomerIdAsync(provider)) + { + var oldModel = await GetEditModel(id); + ModelState.AddModelError(nameof(model.GatewayCustomerId), $"Invalid Gateway Customer Id: {model.GatewayCustomerId}"); + return View(oldModel); + } + if (!await _subscriberService.IsValidGatewaySubscriptionIdAsync(provider)) + { + var oldModel = await GetEditModel(id); + ModelState.AddModelError(nameof(model.GatewaySubscriptionId), $"Invalid Gateway Subscription Id: {model.GatewaySubscriptionId}"); + return View(oldModel); + } + } + provider.Enabled = _accessControlService.UserHasPermission(Permission.Provider_CheckEnabledBox) ? model.Enabled : originalProviderStatus; @@ -382,10 +403,8 @@ public class ProvidersController : Controller } var providerPlans = await _providerPlanRepository.GetByProviderId(id); - - var payByInvoice = - _featureService.IsEnabled(FeatureFlagKeys.PM199566_UpdateMSPToChargeAutomatically) && - (await _stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId)).ApprovedToPayByInvoice(); + var payByInvoice = _featureService.IsEnabled(FeatureFlagKeys.PM199566_UpdateMSPToChargeAutomatically) && + ((await _subscriberService.GetCustomer(provider))?.ApprovedToPayByInvoice() ?? false); return new ProviderEditModel( provider, users, providerOrganizations, diff --git a/src/Admin/AdminConsole/Models/ProviderEditModel.cs b/src/Admin/AdminConsole/Models/ProviderEditModel.cs index 450dfbb2fc..a96c3bd236 100644 --- a/src/Admin/AdminConsole/Models/ProviderEditModel.cs +++ b/src/Admin/AdminConsole/Models/ProviderEditModel.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Enums; using Bit.SharedWeb.Utilities; @@ -87,14 +88,13 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant().Trim(); existingProvider.BillingPhone = BillingPhone?.ToLowerInvariant().Trim(); existingProvider.Enabled = Enabled; - switch (Type) + if (Type.IsStripeSupported()) { - case ProviderType.Msp: - existingProvider.Gateway = Gateway; - existingProvider.GatewayCustomerId = GatewayCustomerId; - existingProvider.GatewaySubscriptionId = GatewaySubscriptionId; - break; + existingProvider.Gateway = Gateway; + existingProvider.GatewayCustomerId = GatewayCustomerId; + existingProvider.GatewaySubscriptionId = GatewaySubscriptionId; } + return existingProvider; } diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs index c8a1496726..55db9dde18 100644 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -36,6 +36,10 @@ public static class BillingExtensions Status: ProviderStatusType.Billable }; + // Reseller types do not have Stripe entities + public static bool IsStripeSupported(this ProviderType providerType) => + providerType is ProviderType.Msp or ProviderType.BusinessUnit; + public static bool SupportsConsolidatedBilling(this ProviderType providerType) => providerType is ProviderType.Msp or ProviderType.BusinessUnit; diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index 5f656b2c22..f88727f37b 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -157,4 +157,22 @@ public interface ISubscriberService Task VerifyBankAccount( ISubscriber subscriber, string descriptorCode); + + /// + /// Validates whether the 's exists in the gateway. + /// If the 's is or empty, returns . + /// + /// The subscriber whose gateway customer ID should be validated. + /// if the gateway customer ID is valid or empty; if the customer doesn't exist in the gateway. + /// Thrown when the is . + Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber); + + /// + /// Validates whether the 's exists in the gateway. + /// If the 's is or empty, returns . + /// + /// The subscriber whose gateway subscription ID should be validated. + /// if the gateway subscription ID is valid or empty; if the subscription doesn't exist in the gateway. + /// Thrown when the is . + Task IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber); } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 73696846ac..53f033de00 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -909,6 +909,44 @@ public class SubscriberService( } } + public async Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + // subscribers are allowed to have no customer id as a business rule + return true; + } + try + { + await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + return true; + } + catch (StripeException e) when (e.StripeError.Code == "resource_missing") + { + return false; + } + } + + public async Task IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + // subscribers are allowed to have no subscription id as a business rule + return true; + } + try + { + await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + return true; + } + catch (StripeException e) when (e.StripeError.Code == "resource_missing") + { + return false; + } + } + #region Shared Utilities private async Task AddBraintreeCustomerIdAsync( diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 3fb134fda8..c41fa81524 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -1765,4 +1765,142 @@ public class SubscriberServiceTests } #endregion + + #region IsValidGatewayCustomerIdAsync + + [Theory, BitAutoData] + public async Task IsValidGatewayCustomerIdAsync_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => + sutProvider.Sut.IsValidGatewayCustomerIdAsync(null)); + } + + [Theory, BitAutoData] + public async Task IsValidGatewayCustomerIdAsync_NullGatewayCustomerId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + + var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); + + Assert.True(result); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task IsValidGatewayCustomerIdAsync_EmptyGatewayCustomerId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = ""; + + var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); + + Assert.True(result); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task IsValidGatewayCustomerIdAsync_ValidCustomerId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer()); + + var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); + + Assert.True(result); + await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + } + + [Theory, BitAutoData] + public async Task IsValidGatewayCustomerIdAsync_InvalidCustomerId_ReturnsFalse( + Organization organization, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; + stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException); + + var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); + + Assert.False(result); + await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + } + + #endregion + + #region IsValidGatewaySubscriptionIdAsync + + [Theory, BitAutoData] + public async Task IsValidGatewaySubscriptionIdAsync_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => + sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(null)); + } + + [Theory, BitAutoData] + public async Task IsValidGatewaySubscriptionIdAsync_NullGatewaySubscriptionId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); + + Assert.True(result); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .SubscriptionGetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task IsValidGatewaySubscriptionIdAsync_EmptyGatewaySubscriptionId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = ""; + + var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); + + Assert.True(result); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .SubscriptionGetAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task IsValidGatewaySubscriptionIdAsync_ValidSubscriptionId_ReturnsTrue( + Organization organization, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); + + var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); + + Assert.True(result); + await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + } + + [Theory, BitAutoData] + public async Task IsValidGatewaySubscriptionIdAsync_InvalidSubscriptionId_ReturnsFalse( + Organization organization, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; + stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException); + + var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); + + Assert.False(result); + await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + } + + #endregion }