diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs
index df333d5d4e..c0c138d0bc 100644
--- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs
+++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs
@@ -22,6 +22,7 @@ using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Billing.Providers.Models;
using Bit.Core.Billing.Providers.Repositories;
using Bit.Core.Billing.Providers.Services;
+using Bit.Core.Billing.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
@@ -53,6 +54,7 @@ public class ProvidersController : Controller
private readonly IPricingClient _pricingClient;
private readonly IStripeAdapter _stripeAdapter;
private readonly IAccessControlService _accessControlService;
+ private readonly ISubscriberService _subscriberService;
private readonly string _stripeUrl;
private readonly string _braintreeMerchantUrl;
private readonly string _braintreeMerchantId;
@@ -73,7 +75,8 @@ public class ProvidersController : Controller
IWebHostEnvironment webHostEnvironment,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
- IAccessControlService accessControlService)
+ IAccessControlService accessControlService,
+ ISubscriberService subscriberService)
{
_organizationRepository = organizationRepository;
_resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand;
@@ -93,6 +96,7 @@ public class ProvidersController : Controller
_braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl();
_braintreeMerchantId = globalSettings.Braintree.MerchantId;
_accessControlService = accessControlService;
+ _subscriberService = subscriberService;
}
[RequirePermission(Permission.Provider_List_View)]
@@ -299,6 +303,23 @@ public class ProvidersController : Controller
model.ToProvider(provider);
+ // validate the stripe ids to prevent saving a bad one
+ if (provider.IsBillable())
+ {
+ if (!await _subscriberService.IsValidGatewayCustomerIdAsync(provider))
+ {
+ var oldModel = await GetEditModel(id);
+ ModelState.AddModelError(nameof(model.GatewayCustomerId), $"Invalid Gateway Customer Id: {model.GatewayCustomerId}");
+ return View(oldModel);
+ }
+ if (!await _subscriberService.IsValidGatewaySubscriptionIdAsync(provider))
+ {
+ var oldModel = await GetEditModel(id);
+ ModelState.AddModelError(nameof(model.GatewaySubscriptionId), $"Invalid Gateway Subscription Id: {model.GatewaySubscriptionId}");
+ return View(oldModel);
+ }
+ }
+
provider.Enabled = _accessControlService.UserHasPermission(Permission.Provider_CheckEnabledBox)
? model.Enabled : originalProviderStatus;
@@ -382,10 +403,8 @@ public class ProvidersController : Controller
}
var providerPlans = await _providerPlanRepository.GetByProviderId(id);
-
- var payByInvoice =
- _featureService.IsEnabled(FeatureFlagKeys.PM199566_UpdateMSPToChargeAutomatically) &&
- (await _stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId)).ApprovedToPayByInvoice();
+ var payByInvoice = _featureService.IsEnabled(FeatureFlagKeys.PM199566_UpdateMSPToChargeAutomatically) &&
+ ((await _subscriberService.GetCustomer(provider))?.ApprovedToPayByInvoice() ?? false);
return new ProviderEditModel(
provider, users, providerOrganizations,
diff --git a/src/Admin/AdminConsole/Models/ProviderEditModel.cs b/src/Admin/AdminConsole/Models/ProviderEditModel.cs
index 450dfbb2fc..a96c3bd236 100644
--- a/src/Admin/AdminConsole/Models/ProviderEditModel.cs
+++ b/src/Admin/AdminConsole/Models/ProviderEditModel.cs
@@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.Billing.Enums;
+using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Providers.Entities;
using Bit.Core.Enums;
using Bit.SharedWeb.Utilities;
@@ -87,14 +88,13 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant().Trim();
existingProvider.BillingPhone = BillingPhone?.ToLowerInvariant().Trim();
existingProvider.Enabled = Enabled;
- switch (Type)
+ if (Type.IsStripeSupported())
{
- case ProviderType.Msp:
- existingProvider.Gateway = Gateway;
- existingProvider.GatewayCustomerId = GatewayCustomerId;
- existingProvider.GatewaySubscriptionId = GatewaySubscriptionId;
- break;
+ existingProvider.Gateway = Gateway;
+ existingProvider.GatewayCustomerId = GatewayCustomerId;
+ existingProvider.GatewaySubscriptionId = GatewaySubscriptionId;
}
+
return existingProvider;
}
diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs
index c8a1496726..55db9dde18 100644
--- a/src/Core/Billing/Extensions/BillingExtensions.cs
+++ b/src/Core/Billing/Extensions/BillingExtensions.cs
@@ -36,6 +36,10 @@ public static class BillingExtensions
Status: ProviderStatusType.Billable
};
+ // Reseller types do not have Stripe entities
+ public static bool IsStripeSupported(this ProviderType providerType) =>
+ providerType is ProviderType.Msp or ProviderType.BusinessUnit;
+
public static bool SupportsConsolidatedBilling(this ProviderType providerType)
=> providerType is ProviderType.Msp or ProviderType.BusinessUnit;
diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs
index 5f656b2c22..f88727f37b 100644
--- a/src/Core/Billing/Services/ISubscriberService.cs
+++ b/src/Core/Billing/Services/ISubscriberService.cs
@@ -157,4 +157,22 @@ public interface ISubscriberService
Task VerifyBankAccount(
ISubscriber subscriber,
string descriptorCode);
+
+ ///
+ /// Validates whether the 's exists in the gateway.
+ /// If the 's is or empty, returns .
+ ///
+ /// The subscriber whose gateway customer ID should be validated.
+ /// if the gateway customer ID is valid or empty; if the customer doesn't exist in the gateway.
+ /// Thrown when the is .
+ Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber);
+
+ ///
+ /// Validates whether the 's exists in the gateway.
+ /// If the 's is or empty, returns .
+ ///
+ /// The subscriber whose gateway subscription ID should be validated.
+ /// if the gateway subscription ID is valid or empty; if the subscription doesn't exist in the gateway.
+ /// Thrown when the is .
+ Task IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber);
}
diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs
index 73696846ac..53f033de00 100644
--- a/src/Core/Billing/Services/Implementations/SubscriberService.cs
+++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs
@@ -909,6 +909,44 @@ public class SubscriberService(
}
}
+ public async Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber)
+ {
+ ArgumentNullException.ThrowIfNull(subscriber);
+ if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
+ {
+ // subscribers are allowed to have no customer id as a business rule
+ return true;
+ }
+ try
+ {
+ await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId);
+ return true;
+ }
+ catch (StripeException e) when (e.StripeError.Code == "resource_missing")
+ {
+ return false;
+ }
+ }
+
+ public async Task IsValidGatewaySubscriptionIdAsync(ISubscriber subscriber)
+ {
+ ArgumentNullException.ThrowIfNull(subscriber);
+ if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
+ {
+ // subscribers are allowed to have no subscription id as a business rule
+ return true;
+ }
+ try
+ {
+ await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId);
+ return true;
+ }
+ catch (StripeException e) when (e.StripeError.Code == "resource_missing")
+ {
+ return false;
+ }
+ }
+
#region Shared Utilities
private async Task AddBraintreeCustomerIdAsync(
diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs
index 3fb134fda8..c41fa81524 100644
--- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs
+++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs
@@ -1765,4 +1765,142 @@ public class SubscriberServiceTests
}
#endregion
+
+ #region IsValidGatewayCustomerIdAsync
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewayCustomerIdAsync_NullSubscriber_ThrowsArgumentNullException(
+ SutProvider sutProvider)
+ {
+ await Assert.ThrowsAsync(() =>
+ sutProvider.Sut.IsValidGatewayCustomerIdAsync(null));
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewayCustomerIdAsync_NullGatewayCustomerId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewayCustomerId = null;
+
+ var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
+
+ Assert.True(result);
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs()
+ .CustomerGetAsync(Arg.Any());
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewayCustomerIdAsync_EmptyGatewayCustomerId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewayCustomerId = "";
+
+ var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
+
+ Assert.True(result);
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs()
+ .CustomerGetAsync(Arg.Any());
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewayCustomerIdAsync_ValidCustomerId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var stripeAdapter = sutProvider.GetDependency();
+ stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer());
+
+ var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
+
+ Assert.True(result);
+ await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId);
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewayCustomerIdAsync_InvalidCustomerId_ReturnsFalse(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var stripeAdapter = sutProvider.GetDependency();
+ var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } };
+ stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException);
+
+ var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization);
+
+ Assert.False(result);
+ await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId);
+ }
+
+ #endregion
+
+ #region IsValidGatewaySubscriptionIdAsync
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewaySubscriptionIdAsync_NullSubscriber_ThrowsArgumentNullException(
+ SutProvider sutProvider)
+ {
+ await Assert.ThrowsAsync(() =>
+ sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(null));
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewaySubscriptionIdAsync_NullGatewaySubscriptionId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewaySubscriptionId = null;
+
+ var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
+
+ Assert.True(result);
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs()
+ .SubscriptionGetAsync(Arg.Any());
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewaySubscriptionIdAsync_EmptyGatewaySubscriptionId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewaySubscriptionId = "";
+
+ var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
+
+ Assert.True(result);
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs()
+ .SubscriptionGetAsync(Arg.Any());
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewaySubscriptionIdAsync_ValidSubscriptionId_ReturnsTrue(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var stripeAdapter = sutProvider.GetDependency();
+ stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription());
+
+ var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
+
+ Assert.True(result);
+ await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId);
+ }
+
+ [Theory, BitAutoData]
+ public async Task IsValidGatewaySubscriptionIdAsync_InvalidSubscriptionId_ReturnsFalse(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var stripeAdapter = sutProvider.GetDependency();
+ var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } };
+ stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException);
+
+ var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization);
+
+ Assert.False(result);
+ await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId);
+ }
+
+ #endregion
}