diff --git a/.gitignore b/.gitignore index 60fc894285..db8cb50f84 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,7 @@ bitwarden_license/src/Sso/Sso.zip /identity.json /api.json /api.public.json +.serena/ # Serena .serena/ diff --git a/Directory.Build.props b/Directory.Build.props index 221200147c..db3ccf40f5 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.12.0 + 2025.12.2 Bit.$(MSBuildProjectName) enable diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 994b305349..12d370395c 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -113,7 +113,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv await _providerBillingService.CreateCustomerForClientOrganization(provider, organization); } - var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + var customer = await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Description = string.Empty, Email = organization.BillingEmail, @@ -138,7 +138,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await _stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; organization.Status = OrganizationStatusType.Created; @@ -148,27 +148,26 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv } else if (organization.IsStripeEnabled()) { - var subscription = await _stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions + var subscription = await _stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer"] }); - if (subscription.Status is StripeConstants.SubscriptionStatus.Canceled or StripeConstants.SubscriptionStatus.IncompleteExpired) { return; } - await _stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail }); if (subscription.Customer.Discount?.Coupon != null) { - await _stripeAdapter.CustomerDeleteDiscountAsync(subscription.CustomerId); + await _stripeAdapter.DeleteCustomerDiscountAsync(subscription.CustomerId); } - await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, DaysUntilDue = 30, diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index 89ef251fd6..3d18e95f7b 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -15,6 +15,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -427,7 +428,7 @@ public class ProviderService : IProviderService if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) { - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = provider.BillingEmail }); @@ -487,7 +488,7 @@ public class ProviderService : IProviderService private async Task GetSubscriptionItemAsync(string subscriptionId, string oldPlanId) { - var subscriptionDetails = await _stripeAdapter.SubscriptionGetAsync(subscriptionId); + var subscriptionDetails = await _stripeAdapter.GetSubscriptionAsync(subscriptionId); return subscriptionDetails.Items.Data.FirstOrDefault(item => item.Price.Id == oldPlanId); } @@ -497,7 +498,7 @@ public class ProviderService : IProviderService { if (subscriptionItem.Price.Id != extractedPlanType) { - await _stripeAdapter.SubscriptionUpdateAsync(subscriptionItem.Subscription, + await _stripeAdapter.UpdateSubscriptionAsync(subscriptionItem.Subscription, new Stripe.SubscriptionUpdateOptions { Items = new List diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs index cc77797307..e140a13841 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -76,8 +75,8 @@ public class GetProviderWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs index 8e8a89ae58..ce2f7a941f 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs @@ -101,7 +101,7 @@ public class BusinessUnitConverter( providerUser.Status = ProviderUserStatusType.Confirmed; // Stripe requires that we clear all the custom fields from the invoice settings if we want to replace them. - await stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -116,7 +116,7 @@ public class BusinessUnitConverter( ["convertedFrom"] = organization.Id.ToString() }; - var updateCustomer = stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + var updateCustomer = stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -148,7 +148,7 @@ public class BusinessUnitConverter( // Replace the existing password manager price with the new business unit price. var updateSubscription = - stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index e352297f1e..41734663c2 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -61,11 +61,11 @@ public class ProviderBillingService( Organization organization, string key) { - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); var subscription = - await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, + await stripeAdapter.CancelSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionCancelOptions { CancellationDetails = new SubscriptionCancellationDetailsOptions @@ -83,7 +83,7 @@ public class ProviderBillingService( if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft) { - await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, + await stripeAdapter.FinalizeInvoiceAsync(subscription.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = true }); } @@ -138,7 +138,7 @@ public class ProviderBillingService( if (clientCustomer.Balance != 0) { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, + await stripeAdapter.CreateCustomerBalanceTransactionAsync(provider.GatewayCustomerId, new CustomerBalanceTransactionCreateOptions { Amount = clientCustomer.Balance, @@ -187,7 +187,7 @@ public class ProviderBillingService( ] }; - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, updateOptions); // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan @@ -275,7 +275,7 @@ public class ProviderBillingService( customerCreateOptions.TaxExempt = TaxExempt.Reverse; } - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); organization.GatewayCustomerId = customer.Id; @@ -525,7 +525,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) @@ -558,7 +558,7 @@ public class ProviderBillingService( try { - return await stripeAdapter.CustomerCreateAsync(options); + return await stripeAdapter.CreateCustomerAsync(options); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid) { @@ -580,7 +580,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); - await stripeAdapter.SetupIntentCancel(setupIntentId, + await stripeAdapter.CancelSetupIntentAsync(setupIntentId, new SetupIntentCancelOptions { CancellationReason = "abandoned" }); await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); break; @@ -638,7 +638,7 @@ public class ProviderBillingService( var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntent = !string.IsNullOrEmpty(setupIntentId) - ? await stripeAdapter.SetupIntentGet(setupIntentId, + ? await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }) : null; @@ -673,7 +673,7 @@ public class ProviderBillingService( try { - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (subscription is { @@ -708,7 +708,7 @@ public class ProviderBillingService( subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource), subscriberService.UpdateTaxInformation(provider, taxInformation)); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically }); } @@ -791,7 +791,7 @@ public class ProviderBillingService( if (subscriptionItemOptionsList.Count > 0) { - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); } } @@ -807,7 +807,7 @@ public class ProviderBillingService( var item = subscription.Items.First(item => item.Price.Id == priceId); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 5b4a0c29cd..696d600348 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.E using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.Utilities.Commands; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -24,7 +25,7 @@ public class PostUserCommand( IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, - IPaymentService paymentService, + IStripePaymentService paymentService, IScimContext scimContext, IFeatureService featureService, IInviteOrganizationUsersCommand inviteOrganizationUsersCommand, diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index bc26fb270a..7141f8429d 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -201,12 +201,15 @@ public class AccountController : Controller returnUrl, state = context.Parameters["state"], userIdentifier = context.Parameters["session_state"], + ssoToken }); } [HttpGet] - public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) + public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier, string ssoToken) { + ValidateSchemeAgainstSsoToken(scheme, ssoToken); + if (string.IsNullOrEmpty(returnUrl)) { returnUrl = "~/"; @@ -235,6 +238,31 @@ public class AccountController : Controller return Challenge(props, scheme); } + /// + /// Validates the scheme (organization ID) against the organization ID found in the ssoToken. + /// + /// The authentication scheme (organization ID) to validate. + /// The SSO token to validate against. + /// Thrown if the scheme (organization ID) does not match the organization ID found in the ssoToken. + private void ValidateSchemeAgainstSsoToken(string scheme, string ssoToken) + { + SsoTokenable tokenable; + + try + { + tokenable = _dataProtector.Unprotect(ssoToken); + } + catch + { + throw new Exception(_i18nService.T("InvalidSsoToken")); + } + + if (!Guid.TryParse(scheme, out var schemeOrgId) || tokenable.OrganizationId != schemeOrgId) + { + throw new Exception(_i18nService.T("SsoOrganizationIdMismatch")); + } + } + [HttpGet] public async Task ExternalCallback() { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index b367b17c73..810429d658 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -131,7 +131,7 @@ public class RemoveOrganizationFromProviderCommandTests Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -156,7 +156,7 @@ public class RemoveOrganizationFromProviderCommandTests "b@example.com" ]); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId, Arg.Is( + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer"))) .Returns(GetSubscription(organization.GatewaySubscriptionId, organization.GatewayCustomerId)); @@ -164,12 +164,14 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - await stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, + await stripeAdapter.Received(1).UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Email == "a@example.com")); - await stripeAdapter.Received(1).CustomerDeleteDiscountAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); + + await stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30)); @@ -226,7 +228,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -239,14 +241,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -315,7 +317,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -328,14 +330,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -434,7 +436,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Any()) + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(new Customer { Id = "customer_id", @@ -444,7 +446,7 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "new_subscription_id" }); diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index 78376f6d98..11ffe115e2 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -12,6 +12,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -757,7 +758,7 @@ public class ProviderServiceTests await organizationRepository.Received(1) .ReplaceAsync(Arg.Is(org => org.BillingEmail == provider.BillingEmail)); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync( organization.GatewayCustomerId, Arg.Is(options => options.Email == provider.BillingEmail)); @@ -828,9 +829,9 @@ public class ProviderServiceTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var subscriptionItem = GetSubscription(organization.GatewaySubscriptionId); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(GetSubscription(organization.GatewaySubscriptionId)); - await sutProvider.GetDependency().SubscriptionUpdateAsync( + await sutProvider.GetDependency().UpdateSubscriptionAsync( organization.GatewaySubscriptionId, SubscriptionUpdateRequest(expectedPlanId, subscriptionItem)); await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs index a7f896ef7a..96dbacfa92 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -63,7 +62,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -95,7 +94,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -129,7 +128,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(false); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -163,7 +162,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -224,7 +223,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "GB" }] @@ -257,7 +256,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -296,7 +295,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -338,7 +337,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -383,7 +382,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -428,7 +427,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -461,7 +460,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) .Returns(new StripeList { Data = [ @@ -470,7 +469,7 @@ public class GetProviderWarningsQueryTests new Registration { Country = "FR" } ] }); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -505,7 +504,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -543,7 +542,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "US" }] diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs index c893886083..48b971a032 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs @@ -144,11 +144,11 @@ public class BusinessUnitConverterTests await businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey); - await _stripeAdapter.Received(2).CustomerUpdateAsync(subscription.CustomerId, Arg.Any()); + await _stripeAdapter.Received(2).UpdateCustomerAsync(subscription.CustomerId, Arg.Any()); var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, enterpriseAnnually.Type); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(subscription.Id, Arg.Is( + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(subscription.Id, Arg.Is( arguments => arguments.Items.Count == 2 && arguments.Items[0].Id == "subscription_item_id" && diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index daf35e7ae9..76c5b30dd8 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -20,7 +20,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; @@ -85,7 +84,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -113,7 +112,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -180,14 +179,14 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => si.Id == "si_ent_annual" && si.Deleted == true) == 1)); var newPlanCfg = MockPlans.Get(command.NewPlan); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => @@ -268,7 +267,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -288,7 +287,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -349,7 +348,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -370,7 +369,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -535,7 +534,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync( + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpdateSubscriptionAsync( Arg.Any(), Arg.Any()); @@ -619,7 +618,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -707,7 +706,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -795,7 +794,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -914,12 +913,12 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -942,7 +941,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); - await stripeAdapter.Received(1).SetupIntentCancel("setup_intent_id", Arg.Is(options => + await stripeAdapter.Received(1).CancelSetupIntentAsync("setup_intent_id", Arg.Is(options => options.CancellationReason == "abandoned")); await sutProvider.GetDependency().Received(1).RemoveSetupIntentForSubscriber(provider.Id); @@ -964,7 +963,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1007,12 +1006,12 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1058,7 +1057,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1100,7 +1099,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1142,7 +1141,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1178,7 +1177,7 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Any()) + stripeAdapter.CreateCustomerAsync(Arg.Any()) .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } }); var actual = await Assert.ThrowsAsync(async () => @@ -1216,7 +1215,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1244,7 +1243,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1272,7 +1271,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1323,7 +1322,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()) + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Any()) .Returns( new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); @@ -1381,7 +1380,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && @@ -1458,7 +1457,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1538,7 +1537,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntentId); - sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is(options => + sutProvider.GetDependency().GetSetupIntentAsync(setupIntentId, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(new SetupIntent { Id = setupIntentId, @@ -1553,7 +1552,7 @@ public class ProviderBillingServiceTests } }); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1635,7 +1634,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1713,7 +1712,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1828,7 +1827,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 20 && providerPlan.PurchasedSeats == 5)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1908,7 +1907,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 50)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1989,7 +1988,7 @@ public class ProviderBillingServiceTests providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 60 && providerPlan.PurchasedSeats == 10)); await stripeAdapter.DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -2062,7 +2061,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 80 && providerPlan.PurchasedSeats == 0)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -2142,7 +2141,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.DidNotReceive().ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 1 && diff --git a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs index c04948e21f..b276174814 100644 --- a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs +++ b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs @@ -3,6 +3,7 @@ using System.Security.Claims; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.Registration; @@ -10,6 +11,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Tokens; using Bit.Sso.Controllers; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -1137,4 +1139,129 @@ public class AccountControllerTest Assert.NotNull(result.user); Assert.Equal(email, result.user.Email); } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMatchingOrgId_Succeeds( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var orgId = organization.Id; + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with matching org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock URL helper for IsLocalUrl check + var urlHelper = Substitute.For(); + urlHelper.IsLocalUrl(returnUrl).Returns(true); + sutProvider.Sut.Url = urlHelper; + + // Mock interaction service for IsValidReturnUrl check + var interactionService = sutProvider.GetDependency(); + interactionService.IsValidReturnUrl(returnUrl).Returns(true); + + // Act + var result = sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken); + + // Assert + var challengeResult = Assert.IsType(result); + Assert.Contains(scheme, challengeResult.AuthenticationSchemes); + Assert.NotNull(challengeResult.Properties); + Assert.Equal(scheme, challengeResult.Properties.Items["scheme"]); + Assert.Equal(returnUrl, challengeResult.Properties.Items["return_url"]); + Assert.Equal(state, challengeResult.Properties.Items["state"]); + Assert.Equal(userIdentifier, challengeResult.Properties.Items["user_identifier"]); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMismatchedOrgId_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var correctOrgId = organization.Id; + var wrongOrgId = Guid.NewGuid(); + var scheme = wrongOrgId.ToString(); // Different from tokenable's org ID + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with different org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); // Contains correctOrgId + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSchemeFormat_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var scheme = "not-a-valid-guid"; + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a valid tokenable + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSsoToken_ThrowsInvalidSsoToken( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "invalid-corrupted-token"; + + // Mock the data protector to throw when trying to unprotect + var dataProtector = sutProvider.GetDependency>(); + dataProtector.Unprotect(ssoToken).Returns(_ => throw new Exception("Token validation failed")); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("InvalidSsoToken", ex.Message); + } } diff --git a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs index ac23e7ecc1..eb8804cac5 100644 --- a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -36,7 +37,7 @@ public class PostUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); sutProvider.GetDependency() .InviteUserAsync(organizationId, diff --git a/dev/secrets.json.example b/dev/secrets.json.example index c6a16846e9..0d4213aec1 100644 --- a/dev/secrets.json.example +++ b/dev/secrets.json.example @@ -33,6 +33,10 @@ "id": "", "key": "" }, + "events": { + "connectionString": "", + "queueName": "event" + }, "licenseDirectory": "", "enableNewDeviceVerification": true, "enableEmailVerification": true diff --git a/global.json b/global.json index d25197db39..4cbe3f083a 100644 --- a/global.json +++ b/global.json @@ -5,6 +5,7 @@ }, "msbuild-sdks": { "Microsoft.Build.Traversal": "4.1.0", - "Microsoft.Build.Sql": "1.0.0" + "Microsoft.Build.Sql": "1.0.0", + "Bitwarden.Server.Sdk": "1.2.0" } } diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 2ea539f39f..a99f70bf65 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -16,6 +16,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -41,7 +42,7 @@ public class OrganizationsController : Controller private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; private readonly IPolicyRepository _policyRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IApplicationCacheService _applicationCacheService; private readonly GlobalSettings _globalSettings; private readonly IProviderRepository _providerRepository; @@ -66,7 +67,7 @@ public class OrganizationsController : Controller ICollectionRepository collectionRepository, IGroupRepository groupRepository, IPolicyRepository policyRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IApplicationCacheService applicationCacheService, GlobalSettings globalSettings, IProviderRepository providerRepository, diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index 9344179a77..b6a959a386 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -339,11 +339,11 @@ public class ProvidersController : Controller ]); await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand); - var customer = await _stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId); if (model.PayByInvoice != customer.ApprovedToPayByInvoice()) { var approvedToPayByInvoice = model.PayByInvoice ? "1" : "0"; - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = new Dictionary { diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index 46dafd65e7..2dd6de89a0 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -8,6 +8,7 @@ using Bit.Admin.Utilities; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index b85a91719c..f42b22b098 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -5,6 +5,7 @@ using Bit.Admin.Models; using Bit.Admin.Services; using Bit.Admin.Utilities; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -20,7 +21,7 @@ public class UsersController : Controller { private readonly IUserRepository _userRepository; private readonly ICipherRepository _cipherRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly GlobalSettings _globalSettings; private readonly IAccessControlService _accessControlService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; @@ -30,7 +31,7 @@ public class UsersController : Controller public UsersController( IUserRepository userRepository, ICipherRepository cipherRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, GlobalSettings globalSettings, IAccessControlService accessControlService, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs index 0b7fe8dffe..f172a23529 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs @@ -1,8 +1,8 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; using Bit.Core.Context; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -12,8 +12,10 @@ namespace Bit.Api.AdminConsole.Controllers; [Authorize("Application")] public class OrganizationIntegrationConfigurationController( ICurrentContext currentContext, - IOrganizationIntegrationRepository integrationRepository, - IOrganizationIntegrationConfigurationRepository integrationConfigurationRepository) : Controller + ICreateOrganizationIntegrationConfigurationCommand createCommand, + IUpdateOrganizationIntegrationConfigurationCommand updateCommand, + IDeleteOrganizationIntegrationConfigurationCommand deleteCommand, + IGetOrganizationIntegrationConfigurationsQuery getQuery) : Controller { [HttpGet("")] public async Task> GetAsync( @@ -24,13 +26,8 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - var configurations = await integrationConfigurationRepository.GetManyByIntegrationAsync(integrationId); + var configurations = await getQuery.GetManyByIntegrationAsync(organizationId, integrationId); return configurations .Select(configuration => new OrganizationIntegrationConfigurationResponseModel(configuration)) .ToList(); @@ -46,19 +43,11 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - var organizationIntegrationConfiguration = model.ToOrganizationIntegrationConfiguration(integrationId); - var configuration = await integrationConfigurationRepository.CreateAsync(organizationIntegrationConfiguration); - return new OrganizationIntegrationConfigurationResponseModel(configuration); + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var created = await createCommand.CreateAsync(organizationId, integrationId, configuration); + + return new OrganizationIntegrationConfigurationResponseModel(created); } [HttpPut("{configurationId:guid}")] @@ -72,26 +61,11 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var updated = await updateCommand.UpdateAsync(organizationId, integrationId, configurationId, configuration); - var newConfiguration = model.ToOrganizationIntegrationConfiguration(configuration); - await integrationConfigurationRepository.ReplaceAsync(newConfiguration); - - return new OrganizationIntegrationConfigurationResponseModel(newConfiguration); + return new OrganizationIntegrationConfigurationResponseModel(updated); } [HttpDelete("{configurationId:guid}")] @@ -101,19 +75,8 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } - - await integrationConfigurationRepository.DeleteAsync(configuration); + await deleteCommand.DeleteAsync(organizationId, integrationId, configurationId); } [HttpPost("{configurationId:guid}/delete")] diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs index 8581c4ae1f..9341392d68 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs @@ -1,6 +1,4 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; @@ -16,38 +14,6 @@ public class OrganizationIntegrationConfigurationRequestModel public string? Template { get; set; } - public bool IsValidForType(IntegrationType integrationType) - { - switch (integrationType) - { - case IntegrationType.CloudBillingSync or IntegrationType.Scim: - return false; - case IntegrationType.Slack: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Webhook: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Hec: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - case IntegrationType.Datadog: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - case IntegrationType.Teams: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - default: - return false; - - } - } - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(Guid organizationIntegrationId) { return new OrganizationIntegrationConfiguration() @@ -59,50 +25,4 @@ public class OrganizationIntegrationConfigurationRequestModel Template = Template }; } - - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(OrganizationIntegrationConfiguration currentConfiguration) - { - currentConfiguration.Configuration = Configuration; - currentConfiguration.EventType = EventType; - currentConfiguration.Filters = Filters; - currentConfiguration.Template = Template; - - return currentConfiguration; - } - - private bool IsConfigurationValid() - { - if (string.IsNullOrWhiteSpace(Configuration)) - { - return false; - } - - try - { - var config = JsonSerializer.Deserialize(Configuration); - return config is not null; - } - catch - { - return false; - } - } - - private bool IsFiltersValid() - { - if (Filters is null) - { - return true; - } - - try - { - var filters = JsonSerializer.Deserialize(Filters); - return filters is not null; - } - catch - { - return false; - } - } } diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 3b2e82121d..58e5db18c2 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; using Bit.Core.Services; @@ -24,7 +25,7 @@ public class MembersController : Controller private readonly ICurrentContext _currentContext; private readonly IUpdateOrganizationUserCommand _updateOrganizationUserCommand; private readonly IUpdateOrganizationUserGroupsCommand _updateOrganizationUserGroupsCommand; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationRepository _organizationRepository; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; @@ -37,7 +38,7 @@ public class MembersController : Controller ICurrentContext currentContext, IUpdateOrganizationUserCommand updateOrganizationUserCommand, IUpdateOrganizationUserGroupsCommand updateOrganizationUserGroupsCommand, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IRemoveOrganizationUserCommand removeOrganizationUserCommand, diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index 99b6a47da0..243f4d3c53 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -10,7 +10,7 @@ namespace Bit.Api.Billing.Controllers; [Route("accounts/billing")] [Authorize("Application")] public class AccountsBillingController( - IPaymentService paymentService, + IStripePaymentService paymentService, IUserService userService, IPaymentHistoryService paymentHistoryService) : Controller { diff --git a/src/Api/Billing/Controllers/AccountsController.cs b/src/Api/Billing/Controllers/AccountsController.cs index e136513c77..5d3e095fdd 100644 --- a/src/Api/Billing/Controllers/AccountsController.cs +++ b/src/Api/Billing/Controllers/AccountsController.cs @@ -79,7 +79,7 @@ public class AccountsController( [HttpGet("subscription")] public async Task GetSubscriptionAsync( [FromServices] GlobalSettings globalSettings, - [FromServices] IPaymentService paymentService) + [FromServices] IStripePaymentService paymentService) { var user = await userService.GetUserByPrincipalAsync(User); if (user == null) diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index a0a3e48b60..e06d946ea0 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -5,7 +5,6 @@ using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -19,7 +18,7 @@ public class OrganizationBillingController( ICurrentContext currentContext, IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IPaymentHistoryService paymentHistoryService) : BaseBillingController { // TODO: Remove when pm-25379-use-new-organization-metadata-structure is removed. diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index 16fb00a3e7..bca5605a8c 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -36,7 +36,7 @@ public class OrganizationsController( IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, IUserService userService, - IPaymentService paymentService, + IStripePaymentService paymentService, ICurrentContext currentContext, IGetCloudOrganizationLicenseQuery getCloudOrganizationLicenseQuery, GlobalSettings globalSettings, diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index d358f8efd2..dfa705a329 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -43,7 +43,7 @@ public class ProviderBillingController( return result; } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = provider.GatewayCustomerId }); @@ -87,7 +87,7 @@ public class ProviderBillingController( return result; } - var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.tax_ids", "discounts", "test_clock"] }); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -96,7 +96,7 @@ public class ProviderBillingController( { var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, plan.Type); - var price = await stripeAdapter.PriceGetAsync(priceId); + var price = await stripeAdapter.GetPriceAsync(priceId); var unitAmount = price.UnitAmountDecimal.HasValue ? price.UnitAmountDecimal.Value / 100M diff --git a/src/Api/Billing/Controllers/StripeController.cs b/src/Api/Billing/Controllers/StripeController.cs index 15fccd16f4..6cb10e3165 100644 --- a/src/Api/Billing/Controllers/StripeController.cs +++ b/src/Api/Billing/Controllers/StripeController.cs @@ -1,5 +1,5 @@ -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -28,7 +28,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } @@ -43,7 +43,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj index fdac4fc3e4..69999dc795 100644 --- a/src/Billing/Billing.csproj +++ b/src/Billing/Billing.csproj @@ -1,9 +1,17 @@  + bitwarden-Billing + + + false + false + false + + diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index b24a8d8c36..f55b4523af 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -29,7 +29,7 @@ public class BitPayController( IUserRepository userRepository, IProviderRepository providerRepository, IMailService mailService, - IPaymentService paymentService, + IStripePaymentService paymentService, ILogger logger, IPremiumUserBillingService premiumUserBillingService) : Controller diff --git a/src/Billing/Controllers/PayPalController.cs b/src/Billing/Controllers/PayPalController.cs index 8039680fd5..70023b6bdb 100644 --- a/src/Billing/Controllers/PayPalController.cs +++ b/src/Billing/Controllers/PayPalController.cs @@ -23,7 +23,7 @@ public class PayPalController : Controller private readonly ILogger _logger; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ITransactionRepository _transactionRepository; private readonly IUserRepository _userRepository; private readonly IProviderRepository _providerRepository; @@ -34,7 +34,7 @@ public class PayPalController : Controller ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ITransactionRepository transactionRepository, IUserRepository userRepository, IProviderRepository providerRepository, diff --git a/src/Billing/Program.cs b/src/Billing/Program.cs index 72ff6072c5..334dc49368 100644 --- a/src/Billing/Program.cs +++ b/src/Billing/Program.cs @@ -8,6 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs index bc3fa1bd56..89e40f0e43 100644 --- a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs @@ -2,8 +2,8 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using OneOf; using Stripe; using Event = Stripe.Event; @@ -59,10 +59,10 @@ public class SetupIntentSucceededHandler( return; } - await stripeAdapter.PaymentMethodAttachAsync(paymentMethod.Id, + await stripeAdapter.AttachPaymentMethodAsync(paymentMethod.Id, new PaymentMethodAttachOptions { Customer = customerId }); - await stripeAdapter.CustomerUpdateAsync(customerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 07ffef064f..c10368d8c0 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -109,8 +109,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler break; } - if (subscription.Status is StripeSubscriptionStatus.Unpaid && - subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + if (await IsPremiumSubscriptionAsync(subscription)) { await CancelSubscription(subscription.Id); await VoidOpenInvoices(subscription.Id); @@ -118,6 +117,20 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + break; + } + case StripeSubscriptionStatus.Incomplete when userId.HasValue: + { + // Handle Incomplete subscriptions for Premium users that have open invoices from failed payments + // This prevents duplicate subscriptions when users retry the subscription flow + if (await IsPremiumSubscriptionAsync(subscription) && + subscription.LatestInvoice is { Status: StripeInvoiceStatus.Open }) + { + await CancelSubscription(subscription.Id); + await VoidOpenInvoices(subscription.Id); + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + } + break; } case StripeSubscriptionStatus.Active when organizationId.HasValue: @@ -190,6 +203,13 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } + private async Task IsPremiumSubscriptionAsync(Subscription subscription) + { + var premiumPlans = await _pricingClient.ListPremiumPlans(); + var premiumPriceIds = premiumPlans.SelectMany(p => new[] { p.Seat.StripePriceId, p.Storage.StripePriceId }).ToHashSet(); + return subscription.Items.Any(i => premiumPriceIds.Contains(i.Price.Id)); + } + /// /// Checks if the provider subscription status has changed from a non-active to an active status type /// If the previous status is already active(active,past-due,trialing),canceled,or null, then this will return false. diff --git a/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs b/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs index 9ebe09ebcc..5dce52d907 100644 --- a/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs @@ -1,8 +1,24 @@ -using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Azure.Messaging.ServiceBus; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Models.Teams; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.Services.NoopImplementations; +using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; +using TableStorageRepos = Bit.Core.Repositories.TableStorage; namespace Microsoft.Extensions.DependencyInjection; @@ -20,8 +36,467 @@ public static class EventIntegrationsServiceCollectionExtensions // This is idempotent for the same named cache, so it's safe to call. services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + // Add Validator + services.TryAddSingleton(); + // Add all commands/queries services.AddOrganizationIntegrationCommandsQueries(); + services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + return services; + } + + /// + /// Registers event write services based on available configuration. + /// + /// The service collection to add services to. + /// The global settings containing event logging configuration. + /// The service collection for chaining. + /// + /// + /// This method registers the appropriate IEventWriteService implementation based on the available + /// configuration, checking in the following priority order: + /// + /// + /// 1. Azure Service Bus - If all Azure Service Bus settings are present, registers + /// EventIntegrationEventWriteService with AzureServiceBusService as the publisher + /// + /// + /// 2. RabbitMQ - If all RabbitMQ settings are present, registers EventIntegrationEventWriteService with + /// RabbitMqService as the publisher + /// + /// + /// 3. Azure Queue Storage - If Events.ConnectionString is present, registers AzureQueueEventWriteService + /// + /// + /// 4. Repository (Self-Hosted) - If SelfHosted is true, registers RepositoryEventWriteService + /// + /// + /// 5. Noop - If none of the above are configured, registers NoopEventWriteService (no-op implementation) + /// + /// + public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) + { + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.Events.QueueName)) + { + services.TryAddSingleton(); + return services; + } + + if (globalSettings.SelfHosted) + { + services.TryAddSingleton(); + return services; + } + + services.TryAddSingleton(); + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing Azure Service Bus configuration. + /// The service collection for chaining. + /// + /// + /// If Azure Service Bus is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When Azure Service Bus is enabled, this method registers: + /// - IAzureServiceBusService and IEventIntegrationPublisher implementations + /// - Table Storage event repository + /// - Azure Table Storage event handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsAzureServiceBusEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing RabbitMQ configuration. + /// The service collection for chaining. + /// + /// + /// If RabbitMQ is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When RabbitMQ is enabled, this method registers: + /// - IRabbitMqService and IEventIntegrationPublisher implementations + /// - Event repository handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsRabbitMqEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers Slack integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Slack configuration. + /// The service collection for chaining. + /// + /// If all required Slack settings are configured (ClientId, ClientSecret, Scopes), registers the full SlackService, + /// including an HttpClient for Slack API calls. Otherwise, registers a NoopSlackService that performs no operations. + /// + public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) + { + services.AddHttpClient(SlackService.HttpClientName); + services.TryAddSingleton(); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers Microsoft Teams integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Teams configuration. + /// The service collection for chaining. + /// + /// If all required Teams settings are configured (ClientId, ClientSecret, Scopes), registers: + /// - TeamsService and its interfaces (IBot, ITeamsService) + /// - IBotFrameworkHttpAdapter with Teams credentials + /// - HttpClient for Teams API calls + /// Otherwise, registers a NoopTeamsService that performs no operations. + /// + public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) + { + services.AddHttpClient(TeamsService.HttpClientName); + services.TryAddSingleton(); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(_ => + new BotFrameworkHttpAdapter( + new TeamsBotCredentialProvider( + clientId: globalSettings.Teams.ClientId, + clientSecret: globalSettings.Teams.ClientSecret + ) + ) + ); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers event integration services including handlers, listeners, and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing integration configuration. + /// The service collection for chaining. + /// + /// + /// This method orchestrates the registration of all event integration components based on the enabled + /// message broker (Azure Service Bus or RabbitMQ). It is an internal method called by the public + /// entry points AddAzureServiceBusListeners and AddRabbitMqListeners. + /// + /// + /// NOTE: If both Azure Service Bus and RabbitMQ are configured, Azure Service Bus takes precedence. This means that + /// Azure Service Bus listeners will be registered (and RabbitMQ listeners will NOT) even if this event is called + /// from AddRabbitMqListeners when Azure Service Bus settings are configured. + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before invoking this method. + /// This method depends on distributed cache infrastructure being available for the keyed extended + /// cache registration. + /// + /// + /// Registered Services: + /// - Keyed ExtendedCache for event integrations + /// - Integration filter service + /// - Integration handlers for Slack, Webhook, Hec, Datadog, and Teams + /// - Hosted services for event and integration listeners (based on enabled message broker) + /// + /// + internal static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, + GlobalSettings globalSettings) + { + // Add common services + // NOTE: AddDistributedCache must be called by the caller before this method + services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + + // Add services in support of handlers + services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); + services.TryAddSingleton(TimeProvider.System); + services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); + services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); + + // Add integration handlers + services.TryAddSingleton, SlackIntegrationHandler>(); + services.TryAddSingleton, WebhookIntegrationHandler>(); + services.TryAddSingleton, DatadogIntegrationHandler>(); + services.TryAddSingleton, TeamsIntegrationHandler>(); + + var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); + var slackConfiguration = new SlackListenerConfiguration(globalSettings); + var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); + var hecConfiguration = new HecListenerConfiguration(globalSettings); + var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); + var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); + + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: repositoryConfiguration, + handler: provider.GetRequiredService(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = repositoryConfiguration.EventPrefetchCount, + MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddAzureServiceBusIntegration(slackConfiguration); + services.AddAzureServiceBusIntegration(webhookConfiguration); + services.AddAzureServiceBusIntegration(hecConfiguration); + services.AddAzureServiceBusIntegration(datadogConfiguration); + services.AddAzureServiceBusIntegration(teamsConfiguration); + + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredService(), + configuration: repositoryConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddRabbitMqIntegration(slackConfiguration); + services.AddRabbitMqIntegration(webhookConfiguration); + services.AddRabbitMqIntegration(hecConfiguration); + services.AddRabbitMqIntegration(datadogConfiguration); + services.AddRabbitMqIntegration(teamsConfiguration); + } + + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. AzureServiceBusEventListenerService - Hosted service for listening to event messages from Azure Service Bus + /// for this integration type + /// 3. AzureServiceBusIntegrationListenerService - Hosted service for listening to integration messages from + /// Azure Service Bus for this integration type + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + /// Service Bus processor options (PrefetchCount and MaxConcurrentCalls) are configured from the listener + /// configuration to optimize message throughput and concurrency. + /// + /// + internal static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.EventPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusIntegrationListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredService>(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. RabbitMqEventListenerService - Hosted service for listening to event messages from RabbitMQ for + /// this integration type + /// 3. RabbitMqIntegrationListenerService - Hosted service for listening to integration messages from RabbitMQ for + /// this integration type + /// + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + internal static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqIntegrationListenerService( + handler: provider.GetRequiredService>(), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService(), + timeProvider: provider.GetRequiredService() + ) + ) + ); return services; } @@ -35,4 +510,50 @@ public static class EventIntegrationsServiceCollectionExtensions return services; } + + internal static IServiceCollection AddOrganizationIntegrationConfigurationCommandsQueries(this IServiceCollection services) + { + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + + return services; + } + + /// + /// Determines if RabbitMQ is enabled for event integrations based on configuration settings. + /// + /// The global settings containing RabbitMQ configuration. + /// True if all required RabbitMQ settings are present; otherwise, false. + /// + /// Requires all the following settings to be configured: + /// - EventLogging.RabbitMq.HostName + /// - EventLogging.RabbitMq.Username + /// - EventLogging.RabbitMq.Password + /// - EventLogging.RabbitMq.EventExchangeName + /// + internal static bool IsRabbitMqEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName); + } + + /// + /// Determines if Azure Service Bus is enabled for event integrations based on configuration settings. + /// + /// The global settings containing Azure Service Bus configuration. + /// True if all required Azure Service Bus settings are present; otherwise, false. + /// + /// Requires both of the following settings to be configured: + /// - EventLogging.AzureServiceBus.ConnectionString + /// - EventLogging.AzureServiceBus.EventTopicName + /// + internal static bool IsAzureServiceBusEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName); + } } diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..cb3ce8b9ea --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,64 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for creating organization integration configurations with validation and cache invalidation support. +/// +public class CreateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : ICreateOrganizationIntegrationConfigurationCommand +{ + public async Task CreateAsync( + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, configuration)) + { + throw new BadRequestException( + $"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + var created = await configurationRepository.CreateAsync(configuration); + + // Invalidate the cached configuration details + // Even though this is a new record, the cache could hold a stale empty list for this + if (created.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: created.EventType.Value + )); + } + + return created; + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..78768fd0d4 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,54 @@ +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for deleting organization integration configurations with cache invalidation support. +/// +public class DeleteOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache) + : IDeleteOrganizationIntegrationConfigurationCommand +{ + public async Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + + await configurationRepository.DeleteAsync(configuration); + + if (configuration.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + } + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..a2078c3c98 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,29 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Query implementation for retrieving organization integration configurations. +/// +public class GetOrganizationIntegrationConfigurationsQuery( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository) + : IGetOrganizationIntegrationConfigurationsQuery +{ + public async Task> GetManyByIntegrationAsync( + Guid organizationId, + Guid integrationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + + var configurations = await configurationRepository.GetManyByIntegrationAsync(integrationId); + return configurations.ToList(); + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..140cc79d1a --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,22 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for creating organization integration configurations. +/// +public interface ICreateOrganizationIntegrationConfigurationCommand +{ + /// + /// Creates a new configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The configuration to create. + /// The created configuration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task CreateAsync(Guid organizationId, Guid integrationId, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..3970676d40 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,19 @@ +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for deleting organization integration configurations. +/// +public interface IDeleteOrganizationIntegrationConfigurationCommand +{ + /// + /// Deletes a configuration from an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to delete. + /// + /// Thrown when the integration or configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..2bf806c458 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,19 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Query interface for retrieving organization integration configurations. +/// +public interface IGetOrganizationIntegrationConfigurationsQuery +{ + /// + /// Retrieves all configurations for a specific organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// A list of configurations associated with the integration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + Task> GetManyByIntegrationAsync(Guid organizationId, Guid integrationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..3e60a0af07 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,25 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for updating organization integration configurations. +/// +public interface IUpdateOrganizationIntegrationConfigurationCommand +{ + /// + /// Updates an existing configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to update. + /// The updated configuration data. + /// The updated configuration. + /// + /// Thrown when the integration or the configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task UpdateAsync(Guid organizationId, Guid integrationId, Guid configurationId, OrganizationIntegrationConfiguration updatedConfiguration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..f619e2ddf2 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,82 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for updating organization integration configurations with validation and cache invalidation support. +/// +public class UpdateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : IUpdateOrganizationIntegrationConfigurationCommand +{ + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, updatedConfiguration)) + { + throw new BadRequestException($"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + updatedConfiguration.Id = configuration.Id; + updatedConfiguration.CreationDate = configuration.CreationDate; + await configurationRepository.ReplaceAsync(updatedConfiguration); + + // If either old or new EventType is null (wildcard), invalidate all cached results + // for the specific integration + if (configuration.EventType == null || updatedConfiguration.EventType == null) + { + // Wildcard involved - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + + return updatedConfiguration; + } + + // Both are specific event types - invalidate specific cache entries + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + + // If event type changed, also clear the new event type's cache + if (configuration.EventType != updatedConfiguration.EventType) + { + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: updatedConfiguration.EventType.Value + )); + } + + return updatedConfiguration; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs index a78dd95260..b9bad6a346 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -18,7 +19,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IGroupRepository _groupRepository; private readonly IEventService _eventService; private readonly IOrganizationService _organizationService; @@ -27,7 +28,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA public ImportOrganizationUsersAndGroupsCommand(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IGroupRepository groupRepository, IEventService eventService, IOrganizationService organizationService) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs index f8bd988cab..2648a2e429 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs @@ -2,10 +2,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; using Bit.Core.AdminConsole.Utilities.Errors; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; @@ -15,7 +15,7 @@ public class InviteOrganizationUsersValidator( IOrganizationRepository organizationRepository, IInviteUsersPasswordManagerValidator inviteUsersPasswordManagerValidator, IUpdateSecretsManagerSubscriptionCommand secretsManagerSubscriptionCommand, - IPaymentService paymentService) : IInviteUsersValidator + IStripePaymentService paymentService) : IInviteUsersValidator { public async Task> ValidateAsync( InviteOrganizationUsersValidationRequest request) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs index 67155fe91a..9ba2fd1596 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs @@ -9,8 +9,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; @@ -22,7 +22,7 @@ public class InviteUsersPasswordManagerValidator( IInviteUsersEnvironmentValidator inviteUsersEnvironmentValidator, IInviteUsersOrganizationValidator inviteUsersOrganizationValidator, IProviderRepository providerRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository ) : IInviteUsersPasswordManagerValidator { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs index 0cae0fcc81..154c3b7319 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs @@ -8,6 +8,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -33,7 +34,7 @@ public interface ICloudOrganizationSignUpCommand public class CloudOrganizationSignUpCommand( IOrganizationUserRepository organizationUserRepository, IOrganizationBillingService organizationBillingService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyService policyService, IOrganizationRepository organizationRepository, IOrganizationApiKeyRepository organizationApiKeyRepository, diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs index 6a81130402..f73c49c811 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,13 +13,13 @@ public class OrganizationDeleteCommand : IOrganizationDeleteCommand { private readonly IApplicationCacheService _applicationCacheService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; public OrganizationDeleteCommand( IApplicationCacheService applicationCacheService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ISsoConfigRepository ssoConfigRepository) { _applicationCacheService = applicationCacheService; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs index 446d7339ca..82260aa6a7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -39,7 +40,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; public ResellerClientOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -48,7 +49,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati IOrganizationUserRepository organizationUserRepository, IEventService eventService, ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, - IPaymentService paymentService) + IStripePaymentService paymentService) { _organizationRepository = organizationRepository; _organizationApiKeyRepository = organizationApiKeyRepository; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs index c52b7c10c9..6a7d068ae1 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs @@ -30,7 +30,7 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp private readonly ILicensingService _licensingService; private readonly IPolicyService _policyService; private readonly IGlobalSettings _globalSettings; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; public SelfHostedOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -44,7 +44,7 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp ILicensingService licensingService, IPolicyService policyService, IGlobalSettings globalSettings, - IPaymentService paymentService) + IStripePaymentService paymentService) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs index 450f425bdf..e4d5a94c4c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs @@ -1,12 +1,12 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Microsoft.Extensions.Logging; namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations; -public class UpdateOrganizationSubscriptionCommand(IPaymentService paymentService, +public class UpdateOrganizationSubscriptionCommand(IStripePaymentService paymentService, IOrganizationRepository repository, TimeProvider timeProvider, ILogger logger) : IUpdateOrganizationSubscriptionCommand diff --git a/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs b/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..48346cbae7 --- /dev/null +++ b/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Services; + +public interface IOrganizationIntegrationConfigurationValidator +{ + /// + /// Validates that the configuration is valid for the given integration type. The configuration must + /// include a Configuration that is valid for the type, valid Filters, and a non-empty Template + /// to pass validation. + /// + /// The type of integration + /// The OrganizationIntegrationConfiguration to validate + /// True if valid, false otherwise + bool ValidateConfiguration(IntegrationType integrationType, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs b/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs index f81175f7b5..4f48b64b5a 100644 --- a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs +++ b/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs @@ -8,7 +8,7 @@ namespace Bit.Core.Services; public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService { public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Events.ConnectionString, "event"), + new QueueClient(globalSettings.Events.ConnectionString, globalSettings.Events.QueueName), JsonHelpers.IgnoreWritingNull) { } diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index f18ecf341b..e1fcbb970d 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -21,6 +21,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -47,7 +48,7 @@ public class OrganizationService : IOrganizationService private readonly IPushNotificationService _pushNotificationService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly ISsoUserRepository _ssoUserRepository; @@ -74,7 +75,7 @@ public class OrganizationService : IOrganizationService IPushNotificationService pushNotificationService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, ISsoUserRepository ssoUserRepository, @@ -358,7 +359,7 @@ public class OrganizationService : IOrganizationService { var newDisplayName = organization.DisplayName(); - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail, diff --git a/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs b/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..2769565675 --- /dev/null +++ b/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,76 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Services; + +public class OrganizationIntegrationConfigurationValidator : IOrganizationIntegrationConfigurationValidator +{ + public bool ValidateConfiguration(IntegrationType integrationType, + OrganizationIntegrationConfiguration configuration) + { + // Validate template is present + if (string.IsNullOrWhiteSpace(configuration.Template)) + { + return false; + } + // If Filters are present, they must be valid + if (!IsFiltersValid(configuration.Filters)) + { + return false; + } + + switch (integrationType) + { + case IntegrationType.CloudBillingSync or IntegrationType.Scim: + return false; + case IntegrationType.Slack: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Webhook: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Hec: + case IntegrationType.Datadog: + case IntegrationType.Teams: + return configuration.Configuration is null; + default: + return false; + } + } + + private static bool IsConfigurationValid(string? configuration) + { + if (string.IsNullOrWhiteSpace(configuration)) + { + return false; + } + + try + { + var config = JsonSerializer.Deserialize(configuration); + return config is not null; + } + catch + { + return false; + } + } + + private static bool IsFiltersValid(string? filters) + { + if (filters is null) + { + return true; + } + + try + { + var filterGroup = JsonSerializer.Deserialize(filters); + return filterGroup is not null; + } + catch + { + return false; + } + } +} diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs index 143da0d67f..2a5e786c98 100644 --- a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -7,8 +7,8 @@ using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; -using Bit.Core.Services; using Microsoft.Extensions.Logging; using OneOf; using Stripe; @@ -125,7 +125,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); @@ -165,7 +165,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } else @@ -181,7 +181,7 @@ public class PreviewOrganizationTaxCommand( var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer"] }); if (subscription.Customer.Discount != null) @@ -259,7 +259,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } }); @@ -278,7 +278,7 @@ public class PreviewOrganizationTaxCommand( return new BadRequest("Organization does not have a subscription."); } - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.tax_ids"] }); var options = GetBaseOptions(subscription.Customer, @@ -336,7 +336,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); diff --git a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs index f00bc00356..a8a236decc 100644 --- a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs @@ -22,14 +22,14 @@ public interface IGetCloudOrganizationLicenseQuery public class GetCloudOrganizationLicenseQuery : IGetCloudOrganizationLicenseQuery { private readonly IInstallationRepository _installationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ILicensingService _licensingService; private readonly IProviderRepository _providerRepository; private readonly IFeatureService _featureService; public GetCloudOrganizationLicenseQuery( IInstallationRepository installationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ILicensingService licensingService, IProviderRepository providerRepository, IFeatureService featureService) diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs index 01e520ea41..af8dfa7aec 100644 --- a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs @@ -9,7 +9,6 @@ using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -201,7 +200,7 @@ public class GetOrganizationWarningsQuery( // ReSharper disable once InvertIf if (subscription.Status == SubscriptionStatus.PastDue) { - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -257,8 +256,8 @@ public class GetOrganizationWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index 65c339fad4..162fb488f6 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -14,7 +14,6 @@ using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -161,7 +160,7 @@ public class OrganizationBillingService( try { // Update the subscription in Stripe - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, updateOptions); organization.PlanType = newPlan.Type; await organizationRepository.ReplaceAsync(organization); } @@ -185,7 +184,7 @@ public class OrganizationBillingService( var newDisplayName = organization.DisplayName(); - await stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail, @@ -324,7 +323,7 @@ public class OrganizationBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -358,7 +357,7 @@ public class OrganizationBillingService( try { - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); organization.Gateway = GatewayType.Stripe; organization.GatewayCustomerId = customer.Id; @@ -509,7 +508,7 @@ public class OrganizationBillingService( subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; await organizationRepository.ReplaceAsync(organization); @@ -537,14 +536,14 @@ public class OrganizationBillingService( customer = customer switch { { Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, TaxExempt = StripeConstants.TaxExempt.Reverse }), { Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, @@ -603,7 +602,7 @@ public class OrganizationBillingService( } } }; - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, options); + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, options); } } diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index f4eca40cae..daf39fb981 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Microsoft.Extensions.Logging; using Stripe; @@ -46,7 +45,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -71,7 +70,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -92,7 +91,7 @@ public class UpdateBillingAddressCommand( await EnableAutomaticTaxAsync(subscriber, customer); var deleteExistingTaxIds = customer.TaxIds?.Any() ?? false - ? customer.TaxIds.Select(taxId => stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id)).ToList() + ? customer.TaxIds.Select(taxId => stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id)).ToList() : []; if (billingAddress.TaxId == null) @@ -101,12 +100,12 @@ public class UpdateBillingAddressCommand( return BillingAddress.From(customer.Address); } - var updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + var updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }); if (billingAddress.TaxId.Code == StripeConstants.TaxIdType.SpanishNIF) { - updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, @@ -130,7 +129,7 @@ public class UpdateBillingAddressCommand( if (subscription is { AutomaticTax.Enabled: false }) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs index 81206b8032..a5a9e3e9c9 100644 --- a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -56,7 +55,7 @@ public class UpdatePaymentMethodCommand( if (billingAddress != null && customer.Address is not { Country: not null, PostalCode: not null }) { - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions @@ -75,7 +74,7 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var setupIntents = await stripeAdapter.SetupIntentList(new SetupIntentListOptions + var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { Expand = ["data.payment_method"], PaymentMethod = token @@ -104,9 +103,9 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); + var paymentMethod = await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { DefaultPaymentMethod = token } @@ -139,7 +138,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomer.Id }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } var payPalAccount = braintreeCustomer.DefaultPaymentMethod as PayPalAccount; @@ -204,7 +203,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = string.Empty }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } } } diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs index 9f9618571e..e03a785278 100644 --- a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; using Stripe; @@ -53,7 +52,7 @@ public class GetPaymentMethodQuery( if (!string.IsNullOrEmpty(setupIntentId)) { - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); diff --git a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs index ec77ee0712..c972c3fe5f 100644 --- a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Stripe; namespace Bit.Core.Billing.Payment.Queries; @@ -48,7 +47,7 @@ public class HasPaymentMethodQuery( return false; } - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 472f31ac4b..ed60e2f11c 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -210,7 +210,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -243,7 +243,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch { @@ -300,7 +300,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( ValidateLocation = ValidateTaxLocationTiming.Immediately } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } private async Task CreateSubscriptionAsync( @@ -349,11 +349,11 @@ public class CreatePremiumCloudHostedSubscriptionCommand( OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs index 5f09b8b77b..07247c83cb 100644 --- a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using Stripe; @@ -56,7 +56,7 @@ public class PreviewPremiumTaxCommand( }); } - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs new file mode 100644 index 0000000000..5ec732920e --- /dev/null +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -0,0 +1,50 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; + +namespace Bit.Core.Billing.Services; + +public interface IStripeAdapter +{ + Task CreateCustomerAsync(CustomerCreateOptions customerCreateOptions); + Task GetCustomerAsync(string id, CustomerGetOptions options = null); + Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null); + Task DeleteCustomerAsync(string id); + Task> ListCustomerPaymentMethodsAsync(string id, CustomerPaymentMethodListOptions options = null); + Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options); + Task CreateSubscriptionAsync(SubscriptionCreateOptions subscriptionCreateOptions); + Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null); + Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null); + Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); + Task UpdateSubscriptionAsync(string id, SubscriptionUpdateOptions options = null); + Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); + Task GetInvoiceAsync(string id, InvoiceGetOptions options); + Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); + Task> SearchInvoiceAsync(InvoiceSearchOptions options); + Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); + Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); + Task SendInvoiceAsync(string id, InvoiceSendOptions options); + Task PayInvoiceAsync(string id, InvoicePayOptions options = null); + Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null); + Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null); + IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options); + IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options); + Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null); + Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null); + Task CreateTaxIdAsync(string id, TaxIdCreateOptions options); + Task DeleteTaxIdAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); + Task> ListChargesAsync(ChargeListOptions options); + Task CreateRefundAsync(RefundCreateOptions options); + Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null); + Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); + Task CreateSetupIntentAsync(SetupIntentCreateOptions options); + Task> ListSetupIntentsAsync(SetupIntentListOptions options); + Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null); + Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null); + Task GetPriceAsync(string id, PriceGetOptions options = null); +} diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Billing/Services/IStripePaymentService.cs similarity index 97% rename from src/Core/Services/IPaymentService.cs rename to src/Core/Billing/Services/IStripePaymentService.cs index b4a4639992..b948cf6921 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Billing/Services/IStripePaymentService.cs @@ -8,9 +8,9 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services; -public interface IPaymentService +public interface IStripePaymentService { Task CancelAndRecoverChargesAsync(ISubscriber subscriber); Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); diff --git a/src/Core/Billing/Services/IStripeSyncService.cs b/src/Core/Billing/Services/IStripeSyncService.cs new file mode 100644 index 0000000000..b56204cd47 --- /dev/null +++ b/src/Core/Billing/Services/IStripeSyncService.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Billing.Services; + +public interface IStripeSyncService +{ + Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress); +} diff --git a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs index 5a8cf16f5a..16b3f7e0c3 100644 --- a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs +++ b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.Billing.Services.Implementations; @@ -23,7 +22,7 @@ public class PaymentHistoryService( return Array.Empty(); } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = subscriber.GatewayCustomerId, Limit = pageSize, diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index daa06b907a..9c85971dff 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -12,7 +12,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -68,7 +67,7 @@ public class PremiumUserBillingService( } }; - customer = await stripeAdapter.CustomerCreateAsync(options); + customer = await stripeAdapter.CreateCustomerAsync(options); user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; @@ -81,7 +80,7 @@ public class PremiumUserBillingService( Balance = customer.Balance + credit }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } @@ -227,7 +226,7 @@ public class PremiumUserBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -260,7 +259,7 @@ public class PremiumUserBillingService( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) @@ -347,11 +346,11 @@ public class PremiumUserBillingService( OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); @@ -387,6 +386,6 @@ public class PremiumUserBillingService( } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs new file mode 100644 index 0000000000..cdc7645042 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -0,0 +1,209 @@ +// FIXME: Update this file to be null safe and then delete the line below + +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; +using Stripe.TestHelpers; +using CustomerService = Stripe.CustomerService; +using RefundService = Stripe.RefundService; + +namespace Bit.Core.Billing.Services.Implementations; + +public class StripeAdapter : IStripeAdapter +{ + private readonly CustomerService _customerService; + private readonly SubscriptionService _subscriptionService; + private readonly InvoiceService _invoiceService; + private readonly PaymentMethodService _paymentMethodService; + private readonly TaxIdService _taxIdService; + private readonly ChargeService _chargeService; + private readonly RefundService _refundService; + private readonly CardService _cardService; + private readonly BankAccountService _bankAccountService; + private readonly PriceService _priceService; + private readonly SetupIntentService _setupIntentService; + private readonly TestClockService _testClockService; + private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; + private readonly RegistrationService _taxRegistrationService; + + public StripeAdapter() + { + _customerService = new CustomerService(); + _subscriptionService = new SubscriptionService(); + _invoiceService = new InvoiceService(); + _paymentMethodService = new PaymentMethodService(); + _taxIdService = new TaxIdService(); + _chargeService = new ChargeService(); + _refundService = new RefundService(); + _cardService = new CardService(); + _bankAccountService = new BankAccountService(); + _priceService = new PriceService(); + _setupIntentService = new SetupIntentService(); + _testClockService = new TestClockService(); + _customerBalanceTransactionService = new CustomerBalanceTransactionService(); + _taxRegistrationService = new RegistrationService(); + } + + /************** + ** CUSTOMER ** + **************/ + public Task CreateCustomerAsync(CustomerCreateOptions options) => + _customerService.CreateAsync(options); + + public Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => + _customerService.DeleteDiscountAsync(customerId, options); + + public Task GetCustomerAsync(string id, CustomerGetOptions options = null) => + _customerService.GetAsync(id, options); + + public Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null) => + _customerService.UpdateAsync(id, options); + + public Task DeleteCustomerAsync(string id) => + _customerService.DeleteAsync(id); + + public async Task> ListCustomerPaymentMethodsAsync(string id, + CustomerPaymentMethodListOptions options = null) + { + var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); + return paymentMethods.Data; + } + + public Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options) => + _customerBalanceTransactionService.CreateAsync(customerId, options); + + /****************** + ** SUBSCRIPTION ** + ******************/ + public Task CreateSubscriptionAsync(SubscriptionCreateOptions options) => + _subscriptionService.CreateAsync(options); + + public Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null) => + _subscriptionService.GetAsync(id, options); + + public Task UpdateSubscriptionAsync(string id, + SubscriptionUpdateOptions options = null) => + _subscriptionService.UpdateAsync(id, options); + + public Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null) => + _subscriptionService.CancelAsync(id, options); + + /************* + ** INVOICE ** + *************/ + public Task GetInvoiceAsync(string id, InvoiceGetOptions options) => + _invoiceService.GetAsync(id, options); + + public async Task> ListInvoicesAsync(StripeInvoiceListOptions options) + { + if (!options.SelectAll) + { + return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; + } + + options.Limit = 100; + + var invoices = new List(); + + await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) + { + invoices.Add(invoice); + } + + return invoices; + } + + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => + _invoiceService.CreatePreviewAsync(options); + + public async Task> SearchInvoiceAsync(InvoiceSearchOptions options) => + (await _invoiceService.SearchAsync(options)).Data; + + public Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options) => + _invoiceService.UpdateAsync(id, options); + + public Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) => + _invoiceService.FinalizeInvoiceAsync(id, options); + + public Task SendInvoiceAsync(string id, InvoiceSendOptions options) => + _invoiceService.SendInvoiceAsync(id, options); + + public Task PayInvoiceAsync(string id, InvoicePayOptions options = null) => + _invoiceService.PayAsync(id, options); + + public Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null) => + _invoiceService.DeleteAsync(id, options); + + public Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null) => + _invoiceService.VoidInvoiceAsync(id, options); + + /******************** + ** PAYMENT METHOD ** + ********************/ + public IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options) => + _paymentMethodService.ListAutoPaging(options); + + public IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options) + => _paymentMethodService.ListAutoPagingAsync(options); + + public Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null) => + _paymentMethodService.AttachAsync(id, options); + + public Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null) => + _paymentMethodService.DetachAsync(id, options); + + /************ + ** TAX ID ** + ************/ + public Task CreateTaxIdAsync(string id, TaxIdCreateOptions options) => + _taxIdService.CreateAsync(id, options); + + public Task DeleteTaxIdAsync(string customerId, string taxIdId, + TaxIdDeleteOptions options = null) => + _taxIdService.DeleteAsync(customerId, taxIdId, options); + + /****************** + ** BANK ACCOUNT ** + ******************/ + public Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) => + _bankAccountService.DeleteAsync(customerId, bankAccount, options); + + /*********** + ** PRICE ** + ***********/ + public Task GetPriceAsync(string id, PriceGetOptions options = null) => + _priceService.GetAsync(id, options); + + /****************** + ** SETUP INTENT ** + ******************/ + public Task CreateSetupIntentAsync(SetupIntentCreateOptions options) => + _setupIntentService.CreateAsync(options); + + public async Task> ListSetupIntentsAsync(SetupIntentListOptions options) => + (await _setupIntentService.ListAsync(options)).Data; + + public Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null) => + _setupIntentService.CancelAsync(id, options); + + public Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null) => + _setupIntentService.GetAsync(id, options); + + /******************* + ** MISCELLANEOUS ** + *******************/ + public Task> ListChargesAsync(ChargeListOptions options) => + _chargeService.ListAsync(options); + + public Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null) => + _taxRegistrationService.ListAsync(options); + + public Task CreateRefundAsync(RefundCreateOptions options) => + _refundService.CreateAsync(options); + + public Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null) => + _cardService.DeleteAsync(customerId, cardId, options); +} diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Billing/Services/Implementations/StripePaymentService.cs similarity index 90% rename from src/Core/Services/Implementations/StripePaymentService.cs rename to src/Core/Billing/Services/Implementations/StripePaymentService.cs index c887a388bd..ffc18aa748 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Billing/Services/Implementations/StripePaymentService.cs @@ -21,9 +21,9 @@ using Stripe; using PaymentMethod = Stripe.PaymentMethod; using StaticStore = Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; -public class StripePaymentService : IPaymentService +public class StripePaymentService : IStripePaymentService { private const string SecretsManagerStandaloneDiscountId = "sm-standalone"; @@ -64,7 +64,7 @@ public class StripePaymentService : IPaymentService await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true); - var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(org.GatewaySubscriptionId); org.ExpirationDate = sub.GetCurrentPeriodEnd(); if (sponsorship is not null) @@ -84,7 +84,7 @@ public class StripePaymentService : IPaymentService { // remember, when in doubt, throw var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] }; - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { throw new GatewayException("Subscription not found."); @@ -107,7 +107,7 @@ public class StripePaymentService : IPaymentService var subUpdateOptions = new SubscriptionUpdateOptions { Items = updatedItemOptions, - ProrationBehavior = invoiceNow ? Constants.AlwaysInvoice : Constants.CreateProrations, + ProrationBehavior = invoiceNow ? Core.Constants.AlwaysInvoice : Core.Constants.CreateProrations, DaysUntilDue = daysUntilDue ?? 1, CollectionMethod = "send_invoice" }; @@ -121,11 +121,11 @@ public class StripePaymentService : IPaymentService { if (sub.Customer is { - Address.Country: not Constants.CountryAbbreviations.UnitedStates, + Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse }) { - await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, + await _stripeAdapter.UpdateCustomerAsync(sub.CustomerId, new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); } @@ -141,9 +141,9 @@ public class StripePaymentService : IPaymentService string paymentIntentClientSecret = null; try { - var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); + var subResponse = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, subUpdateOptions); - var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); + var invoice = await _stripeAdapter.GetInvoiceAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); if (invoice == null) { throw new BadRequestException("Unable to locate draft invoice for subscription update."); @@ -162,9 +162,9 @@ public class StripePaymentService : IPaymentService } else { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, + invoice = await _stripeAdapter.FinalizeInvoiceAsync(subResponse.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = false, }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); + await _stripeAdapter.SendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); paymentIntentClientSecret = null; } } @@ -172,7 +172,7 @@ public class StripePaymentService : IPaymentService catch { // Need to revert the subscription - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { Items = subscriptionUpdate.RevertItemsOptions(sub), // This proration behavior prevents a false "credit" from @@ -187,7 +187,7 @@ public class StripePaymentService : IPaymentService else if (invoice.Status != StripeConstants.InvoiceStatus.Paid) { // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h - invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); + invoice = await _stripeAdapter.PayInvoiceAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } } @@ -196,7 +196,7 @@ public class StripePaymentService : IPaymentService // Change back the subscription collection method and/or days until due if (collectionMethod != "send_invoice" || daysUntilDue == null) { - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CollectionMethod = collectionMethod, @@ -204,14 +204,14 @@ public class StripePaymentService : IPaymentService }); } - var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(sub.CustomerId); var newCoupon = customer.Discount?.Coupon?.Id; if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon)) { // Re-add the lost coupon due to the update. - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { Discounts = [ @@ -284,7 +284,7 @@ public class StripePaymentService : IPaymentService { if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, + await _stripeAdapter.CancelSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionCancelOptions()); } @@ -293,7 +293,7 @@ public class StripePaymentService : IPaymentService return; } - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); if (customer == null) { return; @@ -318,7 +318,7 @@ public class StripePaymentService : IPaymentService } else { - var charges = await _stripeAdapter.ChargeListAsync(new ChargeListOptions + var charges = await _stripeAdapter.ListChargesAsync(new ChargeListOptions { Customer = subscriber.GatewayCustomerId }); @@ -327,12 +327,12 @@ public class StripePaymentService : IPaymentService { foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) { - await _stripeAdapter.RefundCreateAsync(new RefundCreateOptions { Charge = charge.Id }); + await _stripeAdapter.CreateRefundAsync(new RefundCreateOptions { Charge = charge.Id }); } } } - await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); + await _stripeAdapter.DeleteCustomerAsync(subscriber.GatewayCustomerId); } public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Invoice invoice) @@ -340,7 +340,7 @@ public class StripePaymentService : IPaymentService var customerOptions = new CustomerGetOptions(); customerOptions.AddExpand("default_source"); customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); string paymentIntentClientSecret = null; @@ -360,13 +360,13 @@ public class StripePaymentService : IPaymentService // We're going to delete this draft invoice, it can't be paid try { - await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + await _stripeAdapter.DeleteInvoiceAsync(invoice.Id); } catch { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions { AutoAdvance = false }); - await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + await _stripeAdapter.VoidInvoiceAsync(invoice.Id); } throw new BadRequestException("No payment method is available."); @@ -379,7 +379,7 @@ public class StripePaymentService : IPaymentService { // Finalize the invoice (from Draft) w/o auto-advance so we // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + invoice = await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions { AutoAdvance = false, }); var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, }; if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) @@ -414,7 +414,7 @@ public class StripePaymentService : IPaymentService } braintreeTransaction = transactionResult.Target; - invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new InvoiceUpdateOptions + invoice = await _stripeAdapter.UpdateInvoiceAsync(invoice.Id, new InvoiceUpdateOptions { Metadata = new Dictionary { @@ -428,7 +428,7 @@ public class StripePaymentService : IPaymentService try { - invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); + invoice = await _stripeAdapter.PayInvoiceAsync(invoice.Id, invoicePayOptions); } catch (StripeException e) { @@ -438,7 +438,7 @@ public class StripePaymentService : IPaymentService // SCA required, get intent client secret var invoiceGetOptions = new InvoiceGetOptions(); invoiceGetOptions.AddExpand("confirmation_secret"); - invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); + invoice = await _stripeAdapter.GetInvoiceAsync(invoice.Id, invoiceGetOptions); paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret; } else @@ -462,7 +462,7 @@ public class StripePaymentService : IPaymentService return paymentIntentClientSecret; } - invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); + invoice = await _stripeAdapter.VoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); // HACK: Workaround for customer balance credit if (invoice.StartingBalance < 0) @@ -470,12 +470,12 @@ public class StripePaymentService : IPaymentService // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to // credit it back to the customer (even though their docs claim they will), we need to // check that balance against the current customer balance and determine if it needs to be re-applied - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); // Assumption: Customer balance should now be $0, otherwise payment would not have failed. if (customer.Balance == 0) { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Balance = invoice.StartingBalance }); } } @@ -506,7 +506,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -522,9 +522,9 @@ public class StripePaymentService : IPaymentService try { var canceledSub = endOfPeriod - ? await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + ? await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) - : await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions()); + : await _stripeAdapter.CancelSubscriptionAsync(sub.Id, new SubscriptionCancelOptions()); if (!canceledSub.CanceledAt.HasValue) { throw new GatewayException("Unable to cancel subscription."); @@ -551,7 +551,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -563,7 +563,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("Subscription is not marked for cancellation."); } - var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + var updatedSub = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); if (updatedSub.CanceledAt.HasValue) { @@ -578,11 +578,11 @@ public class StripePaymentService : IPaymentService !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); if (customerExists) { - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); } else { - customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions + customer = await _stripeAdapter.CreateCustomerAsync(new CustomerCreateOptions { Email = subscriber.BillingEmailAddress(), Description = subscriber.BillingName(), @@ -591,9 +591,8 @@ public class StripePaymentService : IPaymentService subscriber.GatewayCustomerId = customer.Id; } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) }); - return !customerExists; } @@ -630,7 +629,7 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, + var subscription = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.discount.coupon.applies_to", "discounts.coupon.applies_to", "test_clock"] }); if (subscription == null) @@ -675,7 +674,7 @@ public class StripePaymentService : IPaymentService Subscription = subscriber.GatewaySubscriptionId }; - var upcomingInvoice = await _stripeAdapter.InvoiceCreatePreviewAsync(invoiceCreatePreviewOptions); + var upcomingInvoice = await _stripeAdapter.CreateInvoicePreviewAsync(invoiceCreatePreviewOptions); if (upcomingInvoice != null) { @@ -726,7 +725,7 @@ public class StripePaymentService : IPaymentService return false; } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); return customer?.Discount?.Coupon?.Id == SecretsManagerStandaloneDiscountId; } @@ -738,7 +737,7 @@ public class StripePaymentService : IPaymentService return (null, null); } - var openInvoices = await _stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await _stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -774,7 +773,7 @@ public class StripePaymentService : IPaymentService private PaymentMethod GetLatestCardPaymentMethod(string customerId) { - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( + var cardPaymentMethods = _stripeAdapter.ListPaymentMethodsAutoPaging( new PaymentMethodListOptions { Customer = customerId, Type = "card" }); return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); } @@ -837,7 +836,7 @@ public class StripePaymentService : IPaymentService Customer customer = null; try { - customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); + customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId, options); } catch (StripeException) { @@ -870,21 +869,21 @@ public class StripePaymentService : IPaymentService try { - var paidInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var paidInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "paid" }); - var openInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var openInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "open" }); - var uncollectibleInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var uncollectibleInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, diff --git a/src/Core/Services/Implementations/StripeSyncService.cs b/src/Core/Billing/Services/Implementations/StripeSyncService.cs similarity index 68% rename from src/Core/Services/Implementations/StripeSyncService.cs rename to src/Core/Billing/Services/Implementations/StripeSyncService.cs index b2700e65d1..31dd89d72d 100644 --- a/src/Core/Services/Implementations/StripeSyncService.cs +++ b/src/Core/Billing/Services/Implementations/StripeSyncService.cs @@ -1,6 +1,6 @@ using Bit.Core.Exceptions; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; public class StripeSyncService : IStripeSyncService { @@ -11,7 +11,7 @@ public class StripeSyncService : IStripeSyncService _stripeAdapter = stripeAdapter; } - public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) + public async Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress) { if (string.IsNullOrWhiteSpace(gatewayCustomerId)) { @@ -23,9 +23,9 @@ public class StripeSyncService : IStripeSyncService throw new InvalidEmailException(); } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new Stripe.CustomerUpdateOptions { Email = emailAddress }); } } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 4b2ea26294..7acbe20014 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -15,7 +15,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -78,7 +77,7 @@ public class SubscriberService( { if (subscription.Metadata != null && subscription.Metadata.ContainsKey("organizationId")) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Metadata = metadata }); @@ -97,7 +96,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionCancelAsync(subscription.Id, options); + await stripeAdapter.CancelSubscriptionAsync(subscription.Id, options); } else { @@ -116,7 +115,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, options); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); } } @@ -227,7 +226,7 @@ public class SubscriberService( _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) }; - var customer = await stripeAdapter.CustomerCreateAsync(options); + var customer = await stripeAdapter.CreateCustomerAsync(options); switch (subscriber) { @@ -270,7 +269,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -306,7 +305,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -357,7 +356,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -393,7 +392,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -487,23 +486,23 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(stripeCustomer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(stripeCustomer.Id, source.Id); break; } } } - var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new PaymentMethodListOptions + var paymentMethods = stripeAdapter.ListPaymentMethodsAutoPagingAsync(new PaymentMethodListOptions { Customer = stripeCustomer.Id }); await foreach (var paymentMethod in paymentMethods) { - await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id); + await stripeAdapter.DetachPaymentMethodAsync(paymentMethod.Id); } } } @@ -532,7 +531,7 @@ public class SubscriberService( { case PaymentMethodType.BankAccount: { - var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.SetupIntentList(new SetupIntentListOptions + var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = token }); @@ -569,7 +568,7 @@ public class SubscriberService( await RemoveStripePaymentMethodsAsync(customer); // Attach the incoming payment method. - await stripeAdapter.PaymentMethodAttachAsync(token, + await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); var metadata = customer.Metadata; @@ -581,7 +580,7 @@ public class SubscriberService( } // Set the customer's default payment method in Stripe and remove their Braintree customer ID. - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -644,7 +643,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -662,7 +661,7 @@ public class SubscriberService( if (taxId != null) { - await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + await stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id); } if (!string.IsNullOrWhiteSpace(taxInformation.TaxId)) @@ -685,12 +684,12 @@ public class SubscriberService( try { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = taxIdType, Value = taxInformation.TaxId }); if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInformation.TaxId}" }); } } @@ -736,7 +735,7 @@ public class SubscriberService( Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); break; case @@ -744,14 +743,14 @@ public class SubscriberService( Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.None }); break; } if (!subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -771,7 +770,7 @@ public class SubscriberService( if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -790,7 +789,7 @@ public class SubscriberService( } try { - await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -809,7 +808,7 @@ public class SubscriberService( } try { - await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -828,7 +827,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdKey] = braintreeCustomerId; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -868,7 +867,7 @@ public class SubscriberService( return null; } - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); @@ -886,7 +885,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdOldKey] = value; metadata[BraintreeCustomerIdKey] = null; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -903,18 +902,18 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(customer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(customer.Id, source.Id); break; } } } - var paymentMethods = await stripeAdapter.CustomerListPaymentMethods(customer.Id); + var paymentMethods = await stripeAdapter.ListCustomerPaymentMethodsAsync(customer.Id); - await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.PaymentMethodDetachAsync(pm.Id))); + await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.DetachPaymentMethodAsync(pm.Id))); } private async Task ReplaceBraintreePaymentMethodAsync( diff --git a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs index ee60597601..7f7be9d1eb 100644 --- a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs +++ b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs @@ -7,7 +7,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Repositories; -using Bit.Core.Services; using OneOf.Types; using Stripe; @@ -53,7 +52,7 @@ public class RestartSubscriptionCommand( TrialPeriodDays = 0 }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(options); + var subscription = await stripeAdapter.CreateSubscriptionAsync(options); await EnableAsync(subscriber, subscription); return new None(); } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 2ee6b75664..ec5978988c 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -2,8 +2,8 @@ #nullable disable using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Services; using Stripe; namespace Bit.Core.Billing; @@ -22,7 +22,7 @@ public static class Utilities return null; } - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 3710cb4a23..6d2c2a1673 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -160,7 +160,6 @@ public static class FeatureFlagKeys public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; public const string PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword = "pm-23174-manage-account-recovery-permission-drives-the-need-to-set-master-password"; - public const string RecoveryCodeSupportForSsoRequiredUsers = "pm-21153-recovery-code-support-for-sso-required"; public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 4902d5bdbe..52c0a641ab 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -41,6 +41,7 @@ + diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs index 0aebc3fc3b..6d60f05b2a 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; @@ -13,9 +13,9 @@ public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand { private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; - public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) + public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IStripePaymentService paymentService) { _organizationSponsorshipRepository = organizationSponsorshipRepository; _organizationRepository = organizationRepository; diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs index a26d553570..4b983317c9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -14,14 +15,14 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnte public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; public ValidateSponsorshipCommand( IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) { diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs index a0ce7c03b9..25b84fe989 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; @@ -12,13 +13,13 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class AddSecretsManagerSubscriptionCommand : IAddSecretsManagerSubscriptionCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationService _organizationService; private readonly IProviderRepository _providerRepository; private readonly IPricingClient _pricingClient; public AddSecretsManagerSubscriptionCommand( - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationService organizationService, IProviderRepository providerRepository, IPricingClient pricingClient) diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs index d4e1b3cd8d..baf2616a53 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -18,7 +19,7 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubscriptionCommand { private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; private readonly IServiceAccountRepository _serviceAccountRepository; @@ -29,7 +30,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs public UpdateSecretsManagerSubscriptionCommand( IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger, IServiceAccountRepository serviceAccountRepository, diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs index b704cb0460..092ee0f46e 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs @@ -11,6 +11,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -26,7 +27,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IOrganizationConnectionRepository _organizationConnectionRepository; @@ -41,7 +42,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand IOrganizationUserRepository organizationUserRepository, ICollectionRepository collectionRepository, IGroupRepository groupRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, ISsoConfigRepository ssoConfigRepository, IOrganizationConnectionRepository organizationConnectionRepository, diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs deleted file mode 100644 index 6b2c3c299e..0000000000 --- a/src/Core/Services/IStripeAdapter.cs +++ /dev/null @@ -1,54 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; -using Stripe.Tax; - -namespace Bit.Core.Services; - -public interface IStripeAdapter -{ - Task CustomerCreateAsync(CustomerCreateOptions customerCreateOptions); - Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); - Task CustomerGetAsync(string id, CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task> CustomerListPaymentMethods(string id, CustomerPaymentMethodListOptions options = null); - Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options); - Task SubscriptionCreateAsync(SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null); - Task SubscriptionUpdateAsync(string id, SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null); - Task InvoiceGetAsync(string id, InvoiceGetOptions options); - Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); - Task> InvoiceSearchAsync(InvoiceSearchOptions options); - Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options); - Task InvoicePayAsync(string id, InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options); - IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null); - Task TaxIdCreateAsync(string id, TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); - Task> TaxRegistrationsListAsync(RegistrationListOptions options = null); - Task> ChargeListAsync(ChargeListOptions options); - Task RefundCreateAsync(RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); - Task> PriceListAsync(PriceListOptions options = null); - Task SetupIntentCreate(SetupIntentCreateOptions options); - Task> SetupIntentList(SetupIntentListOptions options); - Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null); - Task SetupIntentGet(string id, SetupIntentGetOptions options = null); - Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options); - Task> TestClockListAsync(); - Task PriceGetAsync(string id, PriceGetOptions options = null); -} diff --git a/src/Core/Services/IStripeSyncService.cs b/src/Core/Services/IStripeSyncService.cs deleted file mode 100644 index 655998805e..0000000000 --- a/src/Core/Services/IStripeSyncService.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Core.Services; - -public interface IStripeSyncService -{ - Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); -} diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs deleted file mode 100644 index 3d1663f021..0000000000 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ /dev/null @@ -1,284 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; -using Stripe.Tax; - -namespace Bit.Core.Services; - -public class StripeAdapter : IStripeAdapter -{ - private readonly CustomerService _customerService; - private readonly SubscriptionService _subscriptionService; - private readonly InvoiceService _invoiceService; - private readonly PaymentMethodService _paymentMethodService; - private readonly TaxIdService _taxIdService; - private readonly ChargeService _chargeService; - private readonly RefundService _refundService; - private readonly CardService _cardService; - private readonly BankAccountService _bankAccountService; - private readonly PlanService _planService; - private readonly PriceService _priceService; - private readonly SetupIntentService _setupIntentService; - private readonly Stripe.TestHelpers.TestClockService _testClockService; - private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; - private readonly Stripe.Tax.RegistrationService _taxRegistrationService; - private readonly CalculationService _calculationService; - - public StripeAdapter() - { - _customerService = new CustomerService(); - _subscriptionService = new SubscriptionService(); - _invoiceService = new InvoiceService(); - _paymentMethodService = new PaymentMethodService(); - _taxIdService = new TaxIdService(); - _chargeService = new ChargeService(); - _refundService = new RefundService(); - _cardService = new CardService(); - _bankAccountService = new BankAccountService(); - _priceService = new PriceService(); - _planService = new PlanService(); - _setupIntentService = new SetupIntentService(); - _testClockService = new Stripe.TestHelpers.TestClockService(); - _customerBalanceTransactionService = new CustomerBalanceTransactionService(); - _taxRegistrationService = new Stripe.Tax.RegistrationService(); - _calculationService = new CalculationService(); - } - - public Task CustomerCreateAsync(CustomerCreateOptions options) - { - return _customerService.CreateAsync(options); - } - - public Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => - _customerService.DeleteDiscountAsync(customerId, options); - - public Task CustomerGetAsync(string id, CustomerGetOptions options = null) - { - return _customerService.GetAsync(id, options); - } - - public Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null) - { - return _customerService.UpdateAsync(id, options); - } - - public Task CustomerDeleteAsync(string id) - { - return _customerService.DeleteAsync(id); - } - - public async Task> CustomerListPaymentMethods(string id, - CustomerPaymentMethodListOptions options = null) - { - var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); - return paymentMethods.Data; - } - - public async Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options) - => await _customerBalanceTransactionService.CreateAsync(customerId, options); - - public Task SubscriptionCreateAsync(SubscriptionCreateOptions options) - { - return _subscriptionService.CreateAsync(options); - } - - public Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null) - { - return _subscriptionService.GetAsync(id, options); - } - - public async Task ProviderSubscriptionGetAsync( - string id, - Guid providerId, - SubscriptionGetOptions options = null) - { - var subscription = await _subscriptionService.GetAsync(id, options); - if (subscription.Metadata.TryGetValue("providerId", out var value) && value == providerId.ToString()) - { - return subscription; - } - - throw new InvalidOperationException("Subscription does not belong to the provider."); - } - - public Task SubscriptionUpdateAsync(string id, - SubscriptionUpdateOptions options = null) - { - return _subscriptionService.UpdateAsync(id, options); - } - - public Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null) - { - return _subscriptionService.CancelAsync(Id, options); - } - - public Task InvoiceGetAsync(string id, InvoiceGetOptions options) - { - return _invoiceService.GetAsync(id, options); - } - - public async Task> InvoiceListAsync(StripeInvoiceListOptions options) - { - if (!options.SelectAll) - { - return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; - } - - options.Limit = 100; - - var invoices = new List(); - - await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) - { - invoices.Add(invoice); - } - - return invoices; - } - - public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options) - { - return _invoiceService.CreatePreviewAsync(options); - } - - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) - => (await _invoiceService.SearchAsync(options)).Data; - - public Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options) - { - return _invoiceService.UpdateAsync(id, options); - } - - public Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) - { - return _invoiceService.FinalizeInvoiceAsync(id, options); - } - - public Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options) - { - return _invoiceService.SendInvoiceAsync(id, options); - } - - public Task InvoicePayAsync(string id, InvoicePayOptions options = null) - { - return _invoiceService.PayAsync(id, options); - } - - public Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null) - { - return _invoiceService.DeleteAsync(id, options); - } - - public Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null) - { - return _invoiceService.VoidInvoiceAsync(id, options); - } - - public IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options) - { - return _paymentMethodService.ListAutoPaging(options); - } - - public IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options) - => _paymentMethodService.ListAutoPagingAsync(options); - - public Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null) - { - return _paymentMethodService.AttachAsync(id, options); - } - - public Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null) - { - return _paymentMethodService.DetachAsync(id, options); - } - - public Task PlanGetAsync(string id, PlanGetOptions options = null) - { - return _planService.GetAsync(id, options); - } - - public Task TaxIdCreateAsync(string id, TaxIdCreateOptions options) - { - return _taxIdService.CreateAsync(id, options); - } - - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - TaxIdDeleteOptions options = null) - { - return _taxIdService.DeleteAsync(customerId, taxIdId); - } - - public Task> TaxRegistrationsListAsync(RegistrationListOptions options = null) - { - return _taxRegistrationService.ListAsync(options); - } - - public Task> ChargeListAsync(ChargeListOptions options) - { - return _chargeService.ListAsync(options); - } - - public Task RefundCreateAsync(RefundCreateOptions options) - { - return _refundService.CreateAsync(options); - } - - public Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null) - { - return _cardService.DeleteAsync(customerId, cardId, options); - } - - public Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null) - { - return _bankAccountService.CreateAsync(customerId, options); - } - - public Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) - { - return _bankAccountService.DeleteAsync(customerId, bankAccount, options); - } - - public async Task> PriceListAsync(PriceListOptions options = null) - { - return await _priceService.ListAsync(options); - } - - public Task SetupIntentCreate(SetupIntentCreateOptions options) - => _setupIntentService.CreateAsync(options); - - public async Task> SetupIntentList(SetupIntentListOptions options) - { - var setupIntents = await _setupIntentService.ListAsync(options); - - return setupIntents.Data; - } - - public Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null) - => _setupIntentService.CancelAsync(id, options); - - public Task SetupIntentGet(string id, SetupIntentGetOptions options = null) - => _setupIntentService.GetAsync(id, options); - - public Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options) - => _setupIntentService.VerifyMicrodepositsAsync(id, options); - - public async Task> TestClockListAsync() - { - var items = new List(); - var options = new Stripe.TestHelpers.TestClockListOptions() - { - Limit = 100 - }; - await foreach (var i in _testClockService.ListAutoPagingAsync(options)) - { - items.Add(i); - } - return items; - } - - public Task PriceGetAsync(string id, PriceGetOptions options = null) - => _priceService.GetAsync(id, options); -} diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 2d2a9f0ae7..fbc382cb08 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -57,7 +57,7 @@ public class UserService : UserManager, IUserService private readonly ILicensingService _licenseService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly IFido2 _fido2; @@ -93,7 +93,7 @@ public class UserService : UserManager, IUserService ILicensingService licenseService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, IFido2 fido2, @@ -534,7 +534,7 @@ public class UserService : UserManager, IUserService try { - await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, + await _stripeSyncService.UpdateCustomerEmailAddressAsync(user.GatewayCustomerId, user.BillingEmailAddress()); } catch (Exception ex) @@ -867,7 +867,7 @@ public class UserService : UserManager, IUserService } string paymentIntentClientSecret = null; - IPaymentService paymentService = null; + IStripePaymentService paymentService = null; if (_globalSettings.SelfHosted) { if (license == null || !_licenseService.VerifyLicense(license)) diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index b0d7da05a2..ddc48521e3 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -56,7 +56,7 @@ public class GlobalSettings : IGlobalSettings public virtual EventLoggingSettings EventLogging { get; set; } = new EventLoggingSettings(); public virtual MailSettings Mail { get; set; } = new MailSettings(); public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); - public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); + public virtual AzureQueueEventSettings Events { get; set; } = new AzureQueueEventSettings(); public virtual DistributedCacheSettings DistributedCache { get; set; } = new DistributedCacheSettings(); public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); public virtual IFileStorageSettings Attachment { get; set; } @@ -395,6 +395,24 @@ public class GlobalSettings : IGlobalSettings } } + public class AzureQueueEventSettings : IConnectionStringSettings + { + private string _connectionString; + private string _queueName; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value?.Trim('"'); + } + + public string QueueName + { + get => _queueName; + set => _queueName = value?.Trim('"'); + } + } + public class ConnectionStringSettings : IConnectionStringSettings { private string _connectionString; diff --git a/src/Core/Utilities/BillingHelpers.cs b/src/Core/Utilities/BillingHelpers.cs index 2c1dfcbbbd..ef0fdf010b 100644 --- a/src/Core/Utilities/BillingHelpers.cs +++ b/src/Core/Utilities/BillingHelpers.cs @@ -1,12 +1,12 @@ -using Bit.Core.Entities; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; using Bit.Core.Exceptions; -using Bit.Core.Services; namespace Bit.Core.Utilities; public static class BillingHelpers { - internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, + internal static async Task AdjustStorageAsync(IStripePaymentService paymentService, IStorableSubscriber storableSubscriber, short storageAdjustmentGb, string storagePlanId, short baseStorageGb) { if (storableSubscriber == null) diff --git a/src/Core/Utilities/CACHING.md b/src/Core/Utilities/CACHING.md index d80e629bdd..c29a14d751 100644 --- a/src/Core/Utilities/CACHING.md +++ b/src/Core/Utilities/CACHING.md @@ -381,7 +381,7 @@ public class OrganizationAbilityService ### Example Usage: Default (Ephemeral Data) -#### 1. Registration (already done in Api, Admin, Billing, Identity, and Notifications Startup.cs files, plus Events and EventsProcessor service collection extensions): +#### 1. Registration (already done in Api, Admin, Billing, Events, EventsProcessor, Identity, and Notifications Startup.cs files): ```csharp services.AddDistributedCache(globalSettings); diff --git a/src/Core/Utilities/EventIntegrationsCacheConstants.cs b/src/Core/Utilities/EventIntegrationsCacheConstants.cs index 6bd90c797d..19cc3f949c 100644 --- a/src/Core/Utilities/EventIntegrationsCacheConstants.cs +++ b/src/Core/Utilities/EventIntegrationsCacheConstants.cs @@ -55,16 +55,16 @@ public static class EventIntegrationsCacheConstants /// Builds a deterministic cache key for an organization's integration configuration details /// . /// - /// The unique identifier of the organization to which the user belongs. + /// The unique identifier of the organization. /// The of the integration. - /// The of the event configured. Can be null to apply to all events. + /// The specific of the event configured. /// /// A cache key for the configuration details. /// public static string BuildCacheKeyForOrganizationIntegrationConfigurationDetails( Guid organizationId, IntegrationType integrationType, - EventType? eventType + EventType eventType ) => $"OrganizationIntegrationConfigurationDetails:{organizationId:N}:{integrationType}:{eventType}"; /// diff --git a/src/Events/Startup.cs b/src/Events/Startup.cs index f67debd092..75301cf08c 100644 --- a/src/Events/Startup.cs +++ b/src/Events/Startup.cs @@ -84,6 +84,8 @@ public class Startup services.AddHostedService(); } + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddRabbitMqListeners(globalSettings); } diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index 8dc0f12c0c..c4c02e32d2 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -6,6 +6,7 @@ using Azure.Storage.Queues; using Bit.Core; using Bit.Core.Models.Data; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; namespace Bit.EventsProcessor; @@ -13,7 +14,7 @@ namespace Bit.EventsProcessor; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IConfiguration _configuration; + private readonly GlobalSettings _globalSettings; private Task _executingTask; private CancellationTokenSource _cts; @@ -22,10 +23,10 @@ public class AzureQueueHostedService : IHostedService, IDisposable public AzureQueueHostedService( ILogger logger, - IConfiguration configuration) + GlobalSettings globalSettings) { _logger = logger; - _configuration = configuration; + _globalSettings = globalSettings; } public Task StartAsync(CancellationToken cancellationToken) @@ -56,11 +57,12 @@ public class AzureQueueHostedService : IHostedService, IDisposable private async Task ExecuteAsync(CancellationToken cancellationToken) { - var storageConnectionString = _configuration["azureStorageConnectionString"]; - var queueName = _configuration["azureQueueServiceQueueName"]; + var storageConnectionString = _globalSettings.Events.ConnectionString; + var queueName = _globalSettings.Events.QueueName; if (string.IsNullOrWhiteSpace(storageConnectionString) || string.IsNullOrWhiteSpace(queueName)) { + _logger.LogInformation("Azure Queue Hosted Service is disabled. Missing connection string or queue name."); return; } diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index 260c501e01..888dda43a1 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -31,7 +31,8 @@ public class Startup // Repositories services.AddDatabaseRepositories(globalSettings); - // Hosted Services + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddAzureServiceBusListeners(globalSettings); services.AddHostedService(); } diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index 66121a783c..7d807d432b 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -100,167 +100,16 @@ public abstract class BaseRequestValidator where T : class protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - if (_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) + var validators = DetermineValidationOrder(context, request, validatorContext); + var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); + if (!allValidationSchemesSuccessful) { - var validators = DetermineValidationOrder(context, request, validatorContext); - var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); - if (!allValidationSchemesSuccessful) - { - // Each validation task is responsible for setting its own non-success status, if applicable. - return; - } - await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, - validatorContext.RememberMeRequested); + // Each validation task is responsible for setting its own non-success status, if applicable. + return; } - else - { - // 1. We need to check if the user is legitimate via the contextually appropriate mechanism - // (webauthn, password, custom token, etc.). - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) - { - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); - return; - } - - // 1.5 Now check the version number of the client. Do this after ValidateContextAsync so that - // we prevent account enumeration. If we were to do this before ValidateContextAsync, then attackers - // could use a known invalid client version and make a request for a user (before we know if they have - // demonstrated ownership of the account via correct credentials) and identify if they exist by getting - // an error response back from the validator saying the user is not compatible with the client. - var clientVersionValid = await ValidateClientVersionAsync(context, validatorContext); - if (!clientVersionValid) - { - return; - } - - // 2. Decide if this user belongs to an organization that requires SSO. - // TODO: Clean up Feature Flag: Remove this if block: PM-28281 - if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) - { - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) - { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; - } - } - else - { - var ssoValid = await _ssoRequestValidator.ValidateAsync(user, request, validatorContext); - if (!ssoValid) - { - // SSO is required - SetValidationErrorResult(context, validatorContext); - return; - } - } - - // 3. Check if 2FA is required. - (validatorContext.TwoFactorRequired, var twoFactorOrganization) = - await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); - - // This flag is used to determine if the user wants a rememberMe token sent when - // authentication is successful. - var returnRememberMeToken = false; - - if (validatorContext.TwoFactorRequired) - { - var twoFactorToken = request.Raw["TwoFactorToken"]; - var twoFactorProvider = request.Raw["TwoFactorProvider"]; - var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - // 3a. Response for 2FA required and not provided state. - if (!validTwoFactorRequest || - !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - if (resultDict == null) - { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); - return; - } - - // Include Master Password Policy in 2FA response. - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - return; - } - - var twoFactorTokenValid = - await _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); - - // 3b. Response for 2FA required but request is not valid or remember token expired state. - if (!twoFactorTokenValid) - { - // The remember me token has expired. - if (twoFactorProviderType == TwoFactorProviderType.Remember) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - - // Include Master Password Policy in 2FA response - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - } - else - { - await SendFailedTwoFactorEmail(user, twoFactorProviderType); - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - } - - return; - } - - // 3c. When the 2FA authentication is successful, we can check if the user wants a - // rememberMe token. - var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; - // Check if the user wants a rememberMe token. - if (twoFactorRemember - // if the 2FA auth was rememberMe do not send another token. - && twoFactorProviderType != TwoFactorProviderType.Remember) - { - returnRememberMeToken = true; - } - } - - // 4. Check if the user is logging in from a new device. - var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); - if (!deviceValid) - { - SetValidationErrorResult(context, validatorContext); - await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); - return; - } - - // 5. Force legacy users to the web for migration. - if (UserService.IsLegacyUser(user) && request.ClientId != "web") - { - await FailAuthForLegacyUserAsync(user, context); - return; - } - - // TODO: PM-24324 - This should be its own validator at some point. - // 6. Auth request handling - if (validatorContext.ValidatedAuthRequest != null) - { - validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; - await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); - } - - await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); - } + await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, + validatorContext.RememberMeRequested); } protected async Task FailAuthForLegacyUserAsync(User user, T context) @@ -291,6 +140,11 @@ public abstract class BaseRequestValidator where T : class return [ () => ValidateGrantSpecificContext(context, validatorContext), + // Now check the version number of the client. Do this after ValidateContextAsync so that + // we prevent account enumeration. If we were to do this before ValidateContextAsync, then attackers + // could use a known invalid client version and make a request for a user (before we know if they have + // demonstrated ownership of the account via correct credentials) and identify if they exist by getting + // an error response back from the validator saying the user is not compatible with the client. () => ValidateClientVersionAsync(context, validatorContext), () => ValidateTwoFactorAsync(context, request, validatorContext), () => ValidateSsoAsync(context, request, validatorContext), @@ -305,6 +159,11 @@ public abstract class BaseRequestValidator where T : class return [ () => ValidateGrantSpecificContext(context, validatorContext), + // Now check the version number of the client. Do this after ValidateContextAsync so that + // we prevent account enumeration. If we were to do this before ValidateContextAsync, then attackers + // could use a known invalid client version and make a request for a user (before we know if they have + // demonstrated ownership of the account via correct credentials) and identify if they exist by getting + // an error response back from the validator saying the user is not compatible with the client. () => ValidateClientVersionAsync(context, validatorContext), () => ValidateSsoAsync(context, request, validatorContext), () => ValidateTwoFactorAsync(context, request, validatorContext), @@ -426,17 +285,22 @@ public abstract class BaseRequestValidator where T : class if (validatorContext.TwoFactorRequired && validatorContext.TwoFactorRecoveryRequested) { - SetSsoResult(context, new Dictionary - { - { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } - }); + SetSsoResult(context, + new Dictionary + { + { + "ErrorModel", + new ErrorResponseModel( + "Two-factor recovery has been performed. SSO authentication is required.") + } + }); return false; } SetSsoResult(context, new Dictionary { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } }); return false; } @@ -717,7 +581,8 @@ public abstract class BaseRequestValidator where T : class /// user trying to login /// magic string identifying the grant type requested /// true if sso required; false if not required or already in process - [Obsolete("This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] + [Obsolete( + "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] private async Task RequireSsoLoginAsync(User user, string grantType) { if (grantType == "authorization_code" || grantType == "client_credentials") diff --git a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs index 81f8ba1c3f..145ecc8737 100644 --- a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs @@ -48,8 +48,6 @@ public class SsoRequestValidator( // evaluated, and recovery will have been performed if requested. // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect // to /login. - // If the feature flag RecoveryCodeSupportForSsoRequiredUsers is set to false then this code is unreachable since - // Two Factor validation occurs after SSO validation in that scenario. if (context.TwoFactorRequired && context.TwoFactorRecoveryRequested) { await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription); @@ -63,10 +61,10 @@ public class SsoRequestValidator( /// /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. - /// If the GrantType is authorization_code or client_credentials we know the user is trying to login + /// If the GrantType is authorization_code or client_credentials we know the user is trying to log in /// using the SSO flow so they are allowed to continue. /// - /// user trying to login + /// user trying to log in /// magic string identifying the grant type requested /// true if sso required; false if not required or already in process private async Task RequireSsoAuthenticationAsync(User user, string grantType) diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index e300921ac9..167595bf89 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -6,13 +6,9 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; -using Azure.Messaging.ServiceBus; using Bit.Core.AdminConsole.AbilitiesCache; using Bit.Core.AdminConsole.Models.Business.Tokenables; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.AdminConsole.Models.Teams; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; using Bit.Core.AdminConsole.Services.NoopImplementations; @@ -73,8 +69,6 @@ using Microsoft.AspNetCore.HttpOverrides; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc.Localization; using Microsoft.Azure.Cosmos.Fluent; -using Microsoft.Bot.Builder; -using Microsoft.Bot.Builder.Integration.AspNet.Core; using Microsoft.Extensions.Caching.Cosmos; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Configuration; @@ -86,7 +80,6 @@ using Microsoft.Extensions.Options; using Microsoft.OpenApi.Models; using StackExchange.Redis; using Swashbuckle.AspNetCore.SwaggerGen; -using ZiggyCreatures.Caching.Fusion; using Constants = Bit.Core.Constants; using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; @@ -245,7 +238,7 @@ public static class ServiceCollectionExtensions PrivateKey = globalSettings.Braintree.PrivateKey }; }); - services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); // Legacy mailer service @@ -525,116 +518,6 @@ public static class ServiceCollectionExtensions return globalSettings; } - public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) - { - if (IsAzureServiceBusEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddSingleton(); - return services; - } - - if (IsRabbitMqEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddSingleton(); - return services; - } - - if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.TryAddSingleton(); - return services; - } - - if (globalSettings.SelfHosted) - { - services.TryAddSingleton(); - return services; - } - - services.TryAddSingleton(); - return services; - } - - public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsAzureServiceBusEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsRabbitMqEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && - CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && - CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) - { - services.AddHttpClient(SlackService.HttpClientName); - services.TryAddSingleton(); - } - else - { - services.TryAddSingleton(); - } - - return services; - } - - public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && - CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && - CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) - { - services.AddHttpClient(TeamsService.HttpClientName); - services.TryAddSingleton(); - services.TryAddSingleton(sp => sp.GetRequiredService()); - services.TryAddSingleton(sp => sp.GetRequiredService()); - services.TryAddSingleton(sp => - new BotFrameworkHttpAdapter( - new TeamsBotCredentialProvider( - clientId: globalSettings.Teams.ClientId, - clientSecret: globalSettings.Teams.ClientSecret - ) - ) - ); - } - else - { - services.TryAddSingleton(); - } - - return services; - } - public static void UseDefaultMiddleware(this IApplicationBuilder app, IWebHostEnvironment env, GlobalSettings globalSettings) { @@ -881,186 +764,6 @@ public static class ServiceCollectionExtensions return (provider, connectionString); } - private static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration - { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), - configurationRepository: provider.GetRequiredService(), - groupRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.EventPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusIntegrationListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredService>(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, - GlobalSettings globalSettings) - { - // Add common services - services.AddDistributedCache(globalSettings); - services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - - // Add services in support of handlers - services.AddSlackService(globalSettings); - services.AddTeamsService(globalSettings); - services.TryAddSingleton(TimeProvider.System); - services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); - services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); - - // Add integration handlers - services.TryAddSingleton, SlackIntegrationHandler>(); - services.TryAddSingleton, WebhookIntegrationHandler>(); - services.TryAddSingleton, DatadogIntegrationHandler>(); - services.TryAddSingleton, TeamsIntegrationHandler>(); - - var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); - var slackConfiguration = new SlackListenerConfiguration(globalSettings); - var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); - var hecConfiguration = new HecListenerConfiguration(globalSettings); - var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); - var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); - - if (IsRabbitMqEnabled(globalSettings)) - { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredService(), - configuration: repositoryConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddRabbitMqIntegration(slackConfiguration); - services.AddRabbitMqIntegration(webhookConfiguration); - services.AddRabbitMqIntegration(hecConfiguration); - services.AddRabbitMqIntegration(datadogConfiguration); - services.AddRabbitMqIntegration(teamsConfiguration); - } - - if (IsAzureServiceBusEnabled(globalSettings)) - { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: repositoryConfiguration, - handler: provider.GetRequiredService(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = repositoryConfiguration.EventPrefetchCount, - MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddAzureServiceBusIntegration(slackConfiguration); - services.AddAzureServiceBusIntegration(webhookConfiguration); - services.AddAzureServiceBusIntegration(hecConfiguration); - services.AddAzureServiceBusIntegration(datadogConfiguration); - services.AddAzureServiceBusIntegration(teamsConfiguration); - } - - return services; - } - - private static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration - { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), - configurationRepository: provider.GetRequiredService(), - groupRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqIntegrationListenerService( - handler: provider.GetRequiredService>(), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService(), - timeProvider: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static bool IsAzureServiceBusEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName); - } - - private static bool IsRabbitMqEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName); - } - /// /// Adds a server with its corresponding OAuth2 client credentials security definition and requirement. /// diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs index 9ab626d3f0..6e1dadb92f 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs @@ -1,18 +1,14 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Controllers; +using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; using Bit.Core.Context; -using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; using NSubstitute; -using NSubstitute.ReturnsExtensions; using Xunit; namespace Bit.Api.Test.AdminConsole.Controllers; @@ -25,823 +21,191 @@ public class OrganizationIntegrationsConfigurationControllerTests public async Task DeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) + Guid integrationId, + Guid configurationId) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegrationConfiguration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegrationConfiguration); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); } [Theory, BitAutoData] + [Obsolete("Obsolete")] public async Task PostDeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegrationConfiguration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegrationConfiguration); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } + await sutProvider.Sut.PostDeleteAsync(organizationId, integrationId, configurationId); - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigDoesNotBelongToIntegration_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = Guid.Empty; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); } [Theory, BitAutoData] public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); } [Theory, BitAutoData] public async Task GetAsync_ConfigurationsExist_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - List organizationIntegrationConfigurations) + Guid integrationId, + List configurations) { - organizationIntegration.OrganizationId = organizationId; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) - .Returns(organizationIntegrationConfigurations); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); Assert.NotNull(result); - Assert.Equal(organizationIntegrationConfigurations.Count, result.Count); + Assert.Equal(configurations.Count, result.Count); Assert.All(result, r => Assert.IsType(r)); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); } [Theory, BitAutoData] public async Task GetAsync_NoConfigurationsExist_ReturnsEmptyList( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration) + Guid integrationId) { - organizationIntegration.OrganizationId = organizationId; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) .Returns([]); - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); + Assert.NotNull(result); Assert.Empty(result); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id)); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); } [Theory, BitAutoData] public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.GetAsync(organizationId, integrationId)); } [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Slack_Succeeds( + public async Task PostAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, + Guid integrationId, + OrganizationIntegrationConfiguration configuration, OrganizationIntegrationConfigurationRequestModel model) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); + sutProvider.GetDependency() + .CreateAsync(organizationId, integrationId, Arg.Any()) + .Returns(configuration); - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, integrationId, model); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(organizationId, integrationId, Arg.Any()); Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); } [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Webhook_Succeeds( + public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeCloudBillingSync_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.CloudBillingSync; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeScim_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Scim; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, Guid.Empty, new OrganizationIntegrationConfigurationRequestModel())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(organizationId, integrationId, new OrganizationIntegrationConfigurationRequestModel())); } [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Slack_Succeeds( + public async Task UpdateAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration configuration, OrganizationIntegrationConfigurationRequestModel model) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); + sutProvider.GetDependency() + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()) + .Returns(configuration); - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); + var updateResponse = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, model); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()); Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); } - [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Webhook_Succeeds( + public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, new OrganizationIntegrationConfigurationRequestModel())); } } diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs deleted file mode 100644 index 8a75db9da8..0000000000 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs +++ /dev/null @@ -1,248 +0,0 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations; - -public class OrganizationIntegrationConfigurationRequestModelTests -{ - [Fact] - public void IsValidForType_CloudBillingSyncIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.CloudBillingSync)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Theory] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Fact] - public void IsValidForType_NullConfiguration_ReturnsTrue() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = null, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.True(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.True(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyTemplate_ReturnsFalse(string? template) - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = template - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Fact] - public void IsValidForType_InvalidJsonConfiguration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{not valid json}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - - [Fact] - public void IsValidForType_InvalidJsonFilters_ReturnsFalse() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = "{Not valid json", - Template = "template" - }; - - Assert.False(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ScimIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Scim)); - } - - [Fact] - public void IsValidForType_ValidSlackConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new SlackIntegrationConfiguration(ChannelId: "C12345")); - - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidSlackConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidNoAuthWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( - Uri: new Uri("https://example.com"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_UnknownIntegrationType_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - var unknownType = (IntegrationType)999; - - Assert.False(condition: model.IsValidForType(unknownType)); - } -} diff --git a/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs index 0309264096..16b9b26436 100644 --- a/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs @@ -28,7 +28,7 @@ public class AccountsControllerTests : IDisposable private readonly IUserService _userService; private readonly IFeatureService _featureService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ILicensingService _licensingService; @@ -39,7 +39,7 @@ public class AccountsControllerTests : IDisposable { _userService = Substitute.For(); _featureService = Substitute.For(); - _paymentService = Substitute.For(); + _paymentService = Substitute.For(); _twoFactorIsEnabledQuery = Substitute.For(); _userAccountKeysQuery = Substitute.For(); _licensingService = Substitute.For(); diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs index d79bfde893..ee0bdc61e4 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs @@ -3,9 +3,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http.HttpResults; @@ -103,7 +103,7 @@ public class OrganizationBillingControllerTests // Manually create a BillingHistoryInfo object to avoid requiring AutoFixture to create HttpResponseHeaders var billingInfo = new BillingHistoryInfo(); - sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); + sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); // Act var result = await sutProvider.Sut.GetHistoryAsync(organizationId); diff --git a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs index a776bbea22..9a3f57c3dc 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs @@ -37,7 +37,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationService _organizationService; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IUserService _userService; private readonly IGetCloudOrganizationLicenseQuery _getCloudOrganizationLicenseQuery; @@ -59,7 +59,7 @@ public class OrganizationsControllerTests : IDisposable _organizationRepository = Substitute.For(); _organizationService = Substitute.For(); _organizationUserRepository = Substitute.For(); - _paymentService = Substitute.For(); + _paymentService = Substitute.For(); Substitute.For(); _ssoConfigRepository = Substitute.For(); Substitute.For(); diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index b7349c09d9..652e82c801 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -10,10 +10,10 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Models.Api; using Bit.Core.Models.BitStripe; -using Bit.Core.Services; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -121,7 +121,7 @@ public class ProviderBillingControllerTests } }; - sutProvider.GetDependency().InvoiceListAsync(Arg.Is( + sutProvider.GetDependency().ListInvoicesAsync(Arg.Is( options => options.Customer == provider.GatewayCustomerId)).Returns(invoices); @@ -301,7 +301,7 @@ public class ProviderBillingControllerTests Status = "unpaid" }; - stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer.tax_ids") && options.Expand.Contains("discounts") && @@ -318,7 +318,7 @@ public class ProviderBillingControllerTests Attempted = true }; - stripeAdapter.InvoiceSearchAsync(Arg.Is( + stripeAdapter.SearchInvoiceAsync(Arg.Is( options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) .Returns([overdueInvoice]); @@ -351,7 +351,7 @@ public class ProviderBillingControllerTests var plan = MockPlans.Get(providerPlan.PlanType); sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); - sutProvider.GetDependency().PriceGetAsync(priceId) + sutProvider.GetDependency().GetPriceAsync(priceId) .Returns(new Price { UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 @@ -459,13 +459,13 @@ public class ProviderBillingControllerTests Status = "active" }; - stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer.tax_ids") && options.Expand.Contains("discounts") && options.Expand.Contains("test_clock"))).Returns(subscription); - stripeAdapter.InvoiceSearchAsync(Arg.Is( + stripeAdapter.SearchInvoiceAsync(Arg.Is( options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) .Returns([]); @@ -498,7 +498,7 @@ public class ProviderBillingControllerTests var plan = MockPlans.Get(providerPlan.PlanType); sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); - sutProvider.GetDependency().PriceGetAsync(priceId) + sutProvider.GetDependency().GetPriceAsync(priceId) .Returns(new Price { UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 diff --git a/test/Billing.Test/Controllers/BitPayControllerTests.cs b/test/Billing.Test/Controllers/BitPayControllerTests.cs index d2d1c5b571..0118009cb7 100644 --- a/test/Billing.Test/Controllers/BitPayControllerTests.cs +++ b/test/Billing.Test/Controllers/BitPayControllerTests.cs @@ -31,7 +31,7 @@ public class BitPayControllerTests private readonly IUserRepository _userRepository = Substitute.For(); private readonly IProviderRepository _providerRepository = Substitute.For(); private readonly IMailService _mailService = Substitute.For(); - private readonly IPaymentService _paymentService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); private readonly IPremiumUserBillingService _premiumUserBillingService = Substitute.For(); diff --git a/test/Billing.Test/Controllers/PayPalControllerTests.cs b/test/Billing.Test/Controllers/PayPalControllerTests.cs index f52a304bb6..da995b6188 100644 --- a/test/Billing.Test/Controllers/PayPalControllerTests.cs +++ b/test/Billing.Test/Controllers/PayPalControllerTests.cs @@ -28,7 +28,7 @@ public class PayPalControllerTests(ITestOutputHelper testOutputHelper) private readonly IOptions _billingSettings = Substitute.For>(); private readonly IMailService _mailService = Substitute.For(); private readonly IOrganizationRepository _organizationRepository = Substitute.For(); - private readonly IPaymentService _paymentService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); private readonly ITransactionRepository _transactionRepository = Substitute.For(); private readonly IUserRepository _userRepository = Substitute.For(); private readonly IProviderRepository _providerRepository = Substitute.For(); diff --git a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs index e9f0d9d0ed..a7aefe3163 100644 --- a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs +++ b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs @@ -4,8 +4,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -61,7 +61,7 @@ public class SetupIntentSucceededHandlerTests // Assert await _setupIntentCache.DidNotReceiveWithAnyArgs().GetSubscriberIdForSetupIntent(Arg.Any()); - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -86,7 +86,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -116,7 +116,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == organization.GatewayCustomerId)); @@ -151,7 +151,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == provider.GatewayCustomerId)); @@ -183,7 +183,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -216,7 +216,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 4a480f8c30..182f09e163 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -20,6 +20,8 @@ using Quartz; using Stripe; using Xunit; using Event = Stripe.Event; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; +using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; namespace Bit.Billing.Test.Services; @@ -400,6 +402,75 @@ public class SubscriptionUpdatedHandlerTests var parsedEvent = new Event { Data = new EventData() }; + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List() }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + } + + [Fact] + public async Task HandleAsync_IncompleteExpiredUserSubscription_DisablesPremiumAndCancelsSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.IncompleteExpired, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); @@ -565,7 +636,7 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), - Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + Plan = new Stripe.Plan { Id = "2023-enterprise-org-seat-annually" } } ] }, @@ -599,7 +670,7 @@ public class SubscriptionUpdatedHandlerTests { Data = [ - new SubscriptionItem { Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } } + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-enterprise-seat-annually" } } ] } }) @@ -933,6 +1004,134 @@ public class SubscriptionUpdatedHandlerTests return (providerId, newSubscription, provider, parsedEvent); } + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithOpenInvoice_CancelsSubscriptionAndDisablesPremium() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var openInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Open + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = openInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List { openInvoice } }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + await _stripeFacade.Received(1) + .VoidInvoice(openInvoice.Id); + } + + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithoutOpenInvoice_DoesNotCancelSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var paidInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Paid + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = paidInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.DidNotReceive() + .DisablePremiumAsync(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .CancelSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .ListInvoices(Arg.Any()); + } + public static IEnumerable GetNonActiveSubscriptions() { return new List diff --git a/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs index f69a61a322..08fcd23969 100644 --- a/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs +++ b/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs @@ -1,10 +1,19 @@ -using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.Services.NoopImplementations; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; using NSubstitute; using StackExchange.Redis; using Xunit; @@ -32,6 +41,7 @@ public class EventIntegrationServiceCollectionExtensionsTests // Mock required repository dependencies for commands _services.TryAddScoped(_ => Substitute.For()); + _services.TryAddScoped(_ => Substitute.For()); _services.TryAddScoped(_ => Substitute.For()); } @@ -45,6 +55,9 @@ public class EventIntegrationServiceCollectionExtensionsTests var cache = provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName); Assert.NotNull(cache); + var validator = provider.GetRequiredService(); + Assert.NotNull(validator); + using var scope = provider.CreateScope(); var sp = scope.ServiceProvider; @@ -52,6 +65,11 @@ public class EventIntegrationServiceCollectionExtensionsTests Assert.NotNull(sp.GetService()); Assert.NotNull(sp.GetService()); Assert.NotNull(sp.GetService()); + + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); } [Fact] @@ -61,8 +79,11 @@ public class EventIntegrationServiceCollectionExtensionsTests var createIntegrationDescriptor = _services.First(s => s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)); + var createConfigDescriptor = _services.First(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); Assert.Equal(ServiceLifetime.Scoped, createIntegrationDescriptor.Lifetime); + Assert.Equal(ServiceLifetime.Scoped, createConfigDescriptor.Lifetime); } [Fact] @@ -117,7 +138,7 @@ public class EventIntegrationServiceCollectionExtensionsTests _services.AddEventIntegrationsCommandsQueries(_globalSettings); var createConfigCmdDescriptors = _services.Where(s => - s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)).ToList(); + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); Assert.Single(createConfigCmdDescriptors); var updateIntegrationCmdDescriptors = _services.Where(s => @@ -148,6 +169,690 @@ public class EventIntegrationServiceCollectionExtensionsTests Assert.Single(createCmdDescriptors); } + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_RegistersAllConfigurationServices() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + Assert.Contains(_services, s => s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IUpdateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IDeleteOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IGetOrganizationIntegrationConfigurationsQuery)); + } + + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + var createCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); + Assert.Single(createCmdDescriptors); + } + + [Fact] + public void IsRabbitMqEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingHostName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingUsername_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingPassword_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = null, + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingExchangeName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingConnectionString_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = null, + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingTopicName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void AddSlackService_AllSettingsPresent_RegistersSlackService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = "test-client-id", + ["GlobalSettings:Slack:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Slack:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddSlackService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = null, + ["GlobalSettings:Slack:ClientSecret"] = null, + ["GlobalSettings:Slack:Scopes"] = null + }); + + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + } + + [Fact] + public void AddTeamsService_AllSettingsPresent_RegistersTeamsServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = "test-client-id", + ["GlobalSettings:Teams:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Teams:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.TryAddScoped(_ => Substitute.For()); + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + + var teamsService = provider.GetService(); + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + + var bot = provider.GetService(); + Assert.NotNull(bot); + Assert.IsType(bot); + + var adapter = provider.GetService(); + Assert.NotNull(adapter); + Assert.IsType(adapter); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddTeamsService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = null, + ["GlobalSettings:Teams:ClientSecret"] = null, + ["GlobalSettings:Teams:Scopes"] = null + }); + + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + var teamsService = provider.GetService(); + + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.TryAddSingleton(TimeProvider.System); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_RegistersCommonServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify common services are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationFilterService)); + Assert.Contains(services, s => s.ServiceType == typeof(TimeProvider)); + + // Verify HttpClients for handlers are registered + var httpClientDescriptors = services.Where(s => s.ServiceType == typeof(IHttpClientFactory)).ToList(); + Assert.NotEmpty(httpClientDescriptors); + } + + [Fact] + public void AddEventIntegrationServices_RegistersIntegrationHandlers() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify integration handlers are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + } + + [Fact] + public void AddEventIntegrationServices_RabbitMqEnabled_RegistersRabbitMqListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for RabbitMQ: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_AzureServiceBusEnabled_RegistersAzureListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_BothEnabled_AzureServiceBusTakesPrecedence() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + // NO RabbitMQ services should be enabled because ASB takes precedence + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_NeitherEnabled_RegistersNoListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register no hosted services when neither RabbitMQ nor Azure Service Bus is enabled + Assert.Equal(0, afterCount - beforeCount); + } + + [Fact] + public void AddEventWriteServices_AzureServiceBusEnabled_RegistersAzureServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_EventsConnectionStringPresent_RegistersAzureQueueService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Events:ConnectionString"] = "DefaultEndpointsProtocol=https;AccountName=test;AccountKey=test;EndpointSuffix=core.windows.net", + ["GlobalSettings:Events:QueueName"] = "event" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(AzureQueueEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_SelfHosted_RegistersRepositoryService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:SelfHosted"] = "true" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(RepositoryEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_NothingEnabled_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(NoopEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_AzureTakesPrecedenceOverRabbitMq() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + services.AddEventWriteServices(globalSettings); + + // Should use Azure Service Bus, not RabbitMQ + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.DoesNotContain(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusEnabled_RegistersAzureServiceBusServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IAzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventRepository)); + Assert.Contains(services, s => s.ServiceType == typeof(AzureTableStorageEventHandler)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddAzureServiceBusListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IRabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(EventRepositoryHandler)); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddRabbitMqListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + private static GlobalSettings CreateGlobalSettings(Dictionary data) { var config = new ConfigurationBuilder() diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..c6c8a44955 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,179 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class CreateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task CreateAsync_Success_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_WildcardSuccess_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..3b12f4bd88 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,211 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class DeleteOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_Success_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_WildcardSuccess_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs new file mode 100644 index 0000000000..18541df53e --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs @@ -0,0 +1,101 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class GetOrganizationIntegrationConfigurationsQueryTests +{ + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_Success_ReturnsConfigurations( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + List configurations) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(integrationId); + Assert.Equal(configurations.Count, result.Count); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_NoConfigurations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns([]); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..c2eeefc087 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,390 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class UpdateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_Success_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WildcardSuccess_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedEventType_UpdatesConfigurationAndInvalidatesCacheForBothTypes( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = EventType.Cipher_Created; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + updatedConfiguration.EventType.Value)); + // Verify RemoveByTagAsync was NOT called since both are specific event types + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromWildcardToSpecific_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; // Wildcard + updatedConfiguration.EventType = EventType.User_LoggedIn; // Specific + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromSpecificToWildcard_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; // Specific + updatedConfiguration.EventType = null; // Wildcard + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs index 916fe981de..50442dd463 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs +++ b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs @@ -17,4 +17,5 @@ public class TestListenerConfiguration : IIntegrationListenerConfiguration public int EventPrefetchCount => 0; public int IntegrationMaxConcurrentCalls => 1; public int IntegrationPrefetchCount => 0; + public string RoutingKey => IntegrationType.ToRoutingKey(); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs index 933bcbc3a1..efcd57b6ad 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.Import; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -57,7 +58,7 @@ public class ImportOrganizationUsersAndGroupsCommandTests var organizationUserRepository = sutProvider.GetDependency(); SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository); - sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id).Returns(existingUsers); sutProvider.GetDependency().GetOccupiedSeatCountByOrganizationIdAsync(org.Id).Returns( new OrganizationSeatCounts diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs index 04ef3961ca..e26d9ce978 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs @@ -3,11 +3,11 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -50,7 +50,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -96,7 +96,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -140,7 +140,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 4 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs index 55e5698ad4..69f69b1d02 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -172,7 +173,7 @@ public class ResellerClientOrganizationSignUpCommandTests private static async Task AssertCleanupIsPerformed(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .CancelAndRecoverChargesAsync(Arg.Any()); await sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs index f9fc086873..47872cc6ab 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs @@ -2,9 +2,9 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Models.StaticStore; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -28,7 +28,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() .AdjustSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); @@ -53,7 +53,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), @@ -81,7 +81,7 @@ public class UpdateOrganizationSubscriptionCommandTests OrganizationSubscriptionUpdate[] subscriptionsToUpdate = [new() { Organization = organization, Plan = new Enterprise2023Plan(true) }]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), Arg.Is(x => x.Type == organization.PlanType), @@ -115,7 +115,7 @@ public class UpdateOrganizationSubscriptionCommandTests new() { Organization = failedOrganization, Plan = new Enterprise2023Plan(true) } ]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == failedOrganization.Id), Arg.Is(x => x.Type == failedOrganization.PlanType), @@ -124,7 +124,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == successfulOrganization.Id), diff --git a/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs new file mode 100644 index 0000000000..1154ad8025 --- /dev/null +++ b/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs @@ -0,0 +1,244 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.Services; + +public class OrganizationIntegrationConfigurationValidatorTests +{ + private readonly OrganizationIntegrationConfigurationValidator _sut = new(); + + [Fact] + public void ValidateConfiguration_CloudBillingSyncIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.CloudBillingSync, configuration)); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyTemplate_ReturnsFalse(string? template) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config2)); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyNonNullConfiguration_ReturnsFalse(string? config) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_NullConfiguration_ReturnsTrue() + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonConfiguration_ReturnsFalse() + { + var config = new OrganizationIntegrationConfiguration + { + Configuration = "{not valid json}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonFilters_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = "template", + Filters = "{Not valid json}" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ScimIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Scim, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidNoAuthWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://localhost"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://example.com"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_UnknownIntegrationType_ReturnsFalse() + { + var unknownType = (IntegrationType)999; + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(unknownType, configuration)); + } +} diff --git a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs index 821ce78074..43a33cda31 100644 --- a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs @@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -1142,7 +1143,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .Received(1) - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(id => id == organization.GatewayCustomerId), Arg.Is(options => options.Email == requestOptionsReturned.Email && options.Description == requestOptionsReturned.Description @@ -1182,7 +1183,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); await organizationRepository .Received(1) .ReplaceAsync(Arg.Is(org => org == organization)); diff --git a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs index ef2b1512c9..2f278dcd20 100644 --- a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs @@ -4,7 +4,7 @@ using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using NSubstitute; @@ -58,7 +58,7 @@ public class PreviewOrganizationTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -68,7 +68,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for sponsored subscription - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -116,7 +116,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8250 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -126,7 +126,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(82.50m, total); // Verify the correct Stripe API call for standalone secrets manager - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -179,7 +179,7 @@ public class PreviewOrganizationTaxCommandTests Total = 12200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -189,7 +189,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(122.00m, total); // Verify the correct Stripe API call for comprehensive purchase with storage and service accounts - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -240,7 +240,7 @@ public class PreviewOrganizationTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -250,7 +250,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify the correct Stripe API call for Families tier (non-seat-based plan) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -292,7 +292,7 @@ public class PreviewOrganizationTaxCommandTests Total = 2700 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -302,7 +302,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(27.00m, total); // Verify the correct Stripe API call for business use in non-US country (tax exempt reverse) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -345,7 +345,7 @@ public class PreviewOrganizationTaxCommandTests Total = 12100 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -355,7 +355,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(121.00m, total); // Verify the correct Stripe API call for Spanish NIF that adds both Spanish NIF and EU VAT tax IDs - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -405,7 +405,7 @@ public class PreviewOrganizationTaxCommandTests Total = 1320 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -415,7 +415,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(13.20m, total); // Verify the correct Stripe API call for free organization upgrade to Teams - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -458,7 +458,7 @@ public class PreviewOrganizationTaxCommandTests Total = 4400 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -468,7 +468,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(44.00m, total); // Verify the correct Stripe API call for free organization upgrade to Families (no SM for Families) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -522,7 +522,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -534,7 +534,7 @@ public class PreviewOrganizationTaxCommandTests Total = 9900 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -543,7 +543,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(9.00m, tax); Assert.Equal(99.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -597,7 +597,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -609,7 +609,7 @@ public class PreviewOrganizationTaxCommandTests Total = 13200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -618,7 +618,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(12.00m, tax); Assert.Equal(132.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -661,7 +661,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -671,7 +671,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for free organization with SM to Enterprise - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -730,7 +730,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -738,7 +738,7 @@ public class PreviewOrganizationTaxCommandTests Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -748,7 +748,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for existing subscription upgrade - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -814,7 +814,7 @@ public class PreviewOrganizationTaxCommandTests } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -822,7 +822,7 @@ public class PreviewOrganizationTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -832,7 +832,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call preserves existing discount - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -876,8 +876,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -919,7 +919,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -927,7 +927,7 @@ public class PreviewOrganizationTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -937,7 +937,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call for PM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -984,7 +984,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -992,7 +992,7 @@ public class PreviewOrganizationTaxCommandTests Total = 13200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1002,7 +1002,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(132.00m, total); // Verify the correct Stripe API call for PM seats + storage - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -1051,7 +1051,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1059,7 +1059,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1069,7 +1069,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for SM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -1119,7 +1119,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1127,7 +1127,7 @@ public class PreviewOrganizationTaxCommandTests Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1137,7 +1137,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for SM seats + service accounts with tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -1200,7 +1200,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1208,7 +1208,7 @@ public class PreviewOrganizationTaxCommandTests Total = 27500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1218,7 +1218,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(275.00m, total); // Verify the correct Stripe API call for comprehensive update with discount and Spanish tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -1276,7 +1276,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1284,7 +1284,7 @@ public class PreviewOrganizationTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1294,7 +1294,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for Families tier (personal usage, no business tax exemption) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -1334,8 +1334,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } [Fact] @@ -1378,7 +1378,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1386,7 +1386,7 @@ public class PreviewOrganizationTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1396,7 +1396,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify only PM seats are included (storage=0 excluded, SM seats=0 so entire SM excluded) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && diff --git a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs index 617a136fab..0ceb257c88 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs @@ -8,7 +8,6 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Platform.Installations; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture; using Bit.Core.Test.Billing.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -59,7 +58,7 @@ public class GetCloudOrganizationLicenseQueryTests { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); @@ -80,7 +79,7 @@ public class GetCloudOrganizationLicenseQueryTests { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); sutProvider.GetDependency() .CreateOrganizationTokenAsync(organization, installationId, subInfo) @@ -119,7 +118,7 @@ public class GetCloudOrganizationLicenseQueryTests installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); - sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs index 05d24bdc34..a7284410fe 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs @@ -8,7 +8,6 @@ using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -382,7 +381,7 @@ public class GetOrganizationWarningsQueryTests var dueDate = now.AddDays(-10); - sutProvider.GetDependency().InvoiceSearchAsync(Arg.Is(options => + sutProvider.GetDependency().SearchInvoiceAsync(Arg.Is(options => options.Query == $"subscription:'{subscriptionId}' status:'open'")).Returns([ new Invoice { DueDate = dueDate, Created = dueDate.AddDays(-30) } ]); @@ -542,7 +541,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -583,7 +582,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -635,7 +634,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -687,7 +686,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -739,7 +738,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -785,7 +784,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs index c42049d5bb..5854d1c3b5 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Microsoft.Extensions.Logging; using NSubstitute; @@ -73,7 +72,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -84,7 +83,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -131,7 +130,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -144,7 +143,7 @@ public class UpdateBillingAddressCommandTests await _subscriberService.Received(1).CreateStripeCustomer(organization); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -192,7 +191,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -204,7 +203,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -260,7 +259,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -272,10 +271,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdDeleteAsync(customer.Id, "tax_id_123"); + await _stripeAdapter.Received(1).DeleteTaxIdAsync(customer.Id, "tax_id_123"); } [Fact] @@ -322,7 +321,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse @@ -334,7 +333,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -384,14 +383,14 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse )).Returns(customer); _stripeAdapter - .TaxIdCreateAsync(customer.Id, + .CreateTaxIdAsync(customer.Id, Arg.Is(options => options.Type == TaxIdType.EUVAT)) .Returns(new TaxId { Type = TaxIdType.EUVAT, Value = "ESA12345678" }); @@ -401,10 +400,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input with { TaxId = new TaxID(TaxIdType.EUVAT, "ESA12345678") }, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdCreateAsync(organization.GatewayCustomerId, Arg.Is( + await _stripeAdapter.Received(1).CreateTaxIdAsync(organization.GatewayCustomerId, Arg.Is( options => options.Type == TaxIdType.SpanishNIF && options.Value == input.TaxId.Value)); } diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs index 72280c4c77..da42127f33 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.Billing.Extensions; using Braintree; @@ -82,7 +81,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -144,7 +143,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -213,7 +212,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -232,7 +231,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("https://example.com", maskedBankAccount.HostedVerificationUrl); await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, Arg.Is(options => + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == string.Empty && options.Metadata[MetadataKeys.RetiredBraintreeCustomerId] == "braintree_customer_id")); } @@ -262,7 +261,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -291,7 +290,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); } @@ -315,7 +314,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -344,10 +343,10 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Address.Country == "US" && options.Address.PostalCode == "12345")); } @@ -468,7 +467,7 @@ public class UpdatePaymentMethodCommandTests var maskedPayPalAccount = maskedPaymentMethod.AsT2; Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id")); } diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs index b6b0d596b3..4e4c5199e2 100644 --- a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Braintree; using Microsoft.Extensions.Logging; @@ -166,7 +165,7 @@ public class GetPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))).Returns( new SetupIntent { diff --git a/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs index c7ab0c17ff..9ade4d0979 100644 --- a/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using NSubstitute; using NSubstitute.ReturnsExtensions; @@ -57,7 +56,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { @@ -162,7 +161,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { @@ -246,7 +245,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index cc9c409b4a..b58b5cd250 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -146,11 +146,11 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSetupIntent = Substitute.For(); mockSetupIntent.Id = "seti_123"; - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); - _stripeAdapter.SetupIntentList(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -158,8 +158,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -200,10 +200,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -211,8 +211,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -243,10 +243,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); @@ -255,8 +255,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); @@ -299,10 +299,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -356,8 +356,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Mock that the user has a payment method (this is the key difference from the credit purchase case) _hasPaymentMethodQuery.Run(Arg.Any()).Returns(true); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -365,7 +365,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any(), Arg.Any(), Arg.Any()); } @@ -415,8 +415,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests _updatePaymentMethodCommand.Run(Arg.Any(), Arg.Any(), Arg.Any()) .Returns(mockMaskedPaymentMethod); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -428,9 +428,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Verify GetCustomerOrThrow was called after updating payment method await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); // Verify no new customer was created - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); // Verify subscription was created - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); // Verify user was updated correctly Assert.True(user.Premium); await _userService.Received(1).SaveUserAsync(user); @@ -474,10 +474,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -525,10 +525,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -577,10 +577,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -628,13 +628,13 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SetupIntentList(Arg.Any()) + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()) .Returns(Task.FromResult(new List())); // Empty list - no setup intent found // Act @@ -681,8 +681,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -690,7 +690,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); Assert.True(user.Premium); Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate); } @@ -716,8 +716,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests Assert.True(result.IsT3); // Assuming T3 is the Unhandled result Assert.IsType(result.AsT3.Exception); // Verify no customer was created or subscription attempted - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateSubscriptionAsync(Arg.Any()); await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); } @@ -767,8 +767,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests ] }; - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); // Act var result = await _command.Run(user, paymentMethod, billingAddress, additionalStorage); diff --git a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs index d0b2eb7aa4..b5afaf65cd 100644 --- a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; @@ -50,7 +50,7 @@ public class PreviewPremiumTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -59,7 +59,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(3.00m, tax); Assert.Equal(33.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -84,7 +84,7 @@ public class PreviewPremiumTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(5, billingAddress); @@ -93,7 +93,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(5.00m, tax); Assert.Equal(55.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -120,7 +120,7 @@ public class PreviewPremiumTaxCommandTests Total = 2750 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -129,7 +129,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(2.50m, tax); Assert.Equal(27.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -154,7 +154,7 @@ public class PreviewPremiumTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(20, billingAddress); @@ -163,7 +163,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(8.00m, tax); Assert.Equal(88.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -190,7 +190,7 @@ public class PreviewPremiumTaxCommandTests Total = 4950 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(10, billingAddress); @@ -199,7 +199,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(4.50m, tax); Assert.Equal(49.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -226,7 +226,7 @@ public class PreviewPremiumTaxCommandTests Total = 3000 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -235,7 +235,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(0.00m, tax); Assert.Equal(30.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -260,7 +260,7 @@ public class PreviewPremiumTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(-5, billingAddress); @@ -269,7 +269,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(6.00m, tax); Assert.Equal(66.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "FR" && @@ -295,7 +295,7 @@ public class PreviewPremiumTaxCommandTests Total = 3123 // $31.23 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); diff --git a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs index 4060b45528..0ca1ecfe73 100644 --- a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs +++ b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs @@ -10,7 +10,6 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -179,7 +178,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -196,7 +195,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(7, capturedOptions.TrialPeriodDays); @@ -255,7 +254,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -272,7 +271,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(0, capturedOptions.TrialPeriodDays); @@ -329,7 +328,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -346,7 +345,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(7, capturedOptions.TrialPeriodDays); @@ -364,7 +363,7 @@ public class OrganizationBillingServiceTests CustomerUpdateOptions capturedOptions = null; sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(id => id == organization.GatewayCustomerId), Arg.Do(options => capturedOptions = options)) .Returns(new Customer()); @@ -375,7 +374,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .CustomerUpdateAsync( + .UpdateCustomerAsync( organization.GatewayCustomerId, Arg.Any()); @@ -401,7 +400,7 @@ public class OrganizationBillingServiceTests CustomerUpdateOptions capturedOptions = null; sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(id => id == organization.GatewayCustomerId), Arg.Do(options => capturedOptions = options)) .Returns(new Customer()); @@ -412,7 +411,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .CustomerUpdateAsync( + .UpdateCustomerAsync( organization.GatewayCustomerId, Arg.Any()); @@ -445,6 +444,6 @@ public class OrganizationBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } } diff --git a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs index 06a408c5a8..cd4c5effbe 100644 --- a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs +++ b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs @@ -1,9 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -19,7 +19,7 @@ public class PaymentHistoryServiceTests var subscriber = new Organization { GatewayCustomerId = "cus_id", GatewaySubscriptionId = "sub_id" }; var invoices = new List { new() { Id = "in_id" } }; var stripeAdapter = Substitute.For(); - stripeAdapter.InvoiceListAsync(Arg.Any()).Returns(invoices); + stripeAdapter.ListInvoicesAsync(Arg.Any()).Returns(invoices); var transactionRepository = Substitute.For(); var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository); @@ -29,7 +29,7 @@ public class PaymentHistoryServiceTests // Assert Assert.NotEmpty(result); Assert.Single(result); - await stripeAdapter.Received(1).InvoiceListAsync(Arg.Any()); + await stripeAdapter.Received(1).ListInvoicesAsync(Arg.Any()); } [Fact] diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs similarity index 96% rename from test/Core.Test/Services/StripePaymentServiceTests.cs rename to test/Core.Test/Billing/Services/StripePaymentServiceTests.cs index 8f556be57a..73f28113ca 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs @@ -1,7 +1,8 @@ using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -49,7 +50,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -100,7 +101,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -159,7 +160,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -198,7 +199,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -256,7 +257,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -295,7 +296,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -332,7 +333,7 @@ public class StripePaymentServiceTests }; sutProvider.GetDependency() - .SubscriptionGetAsync( + .GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Any()) .Returns(subscription); @@ -367,7 +368,7 @@ public class StripePaymentServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync( + .GetSubscriptionAsync( Arg.Any(), Arg.Any()) .Returns(subscription); @@ -376,7 +377,7 @@ public class StripePaymentServiceTests await sutProvider.Sut.GetSubscriptionAsync(subscriber); // Assert - Verify expand options are correct - await stripeAdapter.Received(1).SubscriptionGetAsync( + await stripeAdapter.Received(1).GetSubscriptionAsync( subscriber.GatewaySubscriptionId, Arg.Is(o => o.Expand.Contains("customer.discount.coupon.applies_to") && @@ -405,6 +406,6 @@ public class StripePaymentServiceTests // Verify no Stripe API calls were made await sutProvider.GetDependency() .DidNotReceive() - .SubscriptionGetAsync(Arg.Any(), Arg.Any()); + .GetSubscriptionAsync(Arg.Any(), Arg.Any()); } } diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 50fb160754..2f938065e5 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,10 +3,10 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Tax.Models; using Bit.Core.Enums; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -44,7 +44,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); await ThrowsBillingExceptionAsync(() => @@ -52,11 +52,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -81,7 +81,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -95,12 +95,12 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is( + .UpdateSubscriptionAsync(subscriptionId, Arg.Is( options => options.Metadata["cancellingUserId"] == userId.ToString())); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -127,7 +127,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -141,11 +141,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -170,7 +170,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -184,7 +184,7 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is(options => + .UpdateSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancelAtPeriodEnd == true && options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason && @@ -192,7 +192,7 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -223,7 +223,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -237,7 +237,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -253,7 +253,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomer(organization); @@ -287,7 +287,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); @@ -301,7 +301,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -318,7 +318,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); @@ -351,7 +351,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -388,7 +388,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -442,7 +442,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -478,7 +478,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -498,7 +498,7 @@ public class SubscriberServiceTests { var customer = new Customer { Id = provider.GatewayCustomerId }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -521,7 +521,7 @@ public class SubscriberServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntent.Id); - sutProvider.GetDependency().SetupIntentGet(setupIntent.Id, + sutProvider.GetDependency().GetSetupIntentAsync(setupIntent.Id, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(setupIntent); var paymentMethod = await sutProvider.Sut.GetPaymentSource(provider); @@ -541,7 +541,7 @@ public class SubscriberServiceTests DefaultSource = new BankAccount { Status = "verified", BankName = "Chase", Last4 = "9999" } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -564,7 +564,7 @@ public class SubscriberServiceTests DefaultSource = new Card { Brand = "Visa", Last4 = "9999", ExpMonth = 9, ExpYear = 2028 } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -596,7 +596,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -636,7 +636,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -650,7 +650,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -666,7 +666,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscription(organization); @@ -698,7 +698,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); @@ -712,7 +712,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -729,7 +729,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); @@ -760,7 +760,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (braintreeGateway, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -795,7 +795,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -832,7 +832,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -887,7 +887,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -946,21 +946,21 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List())); await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.Received(1).BankAccountDeleteAsync(stripeCustomer.Id, bankAccountId); + await stripeAdapter.Received(1).DeleteBankAccountAsync(stripeCustomer.Id, bankAccountId); - await stripeAdapter.Received(1).CardDeleteAsync(stripeCustomer.Id, cardId); + await stripeAdapter.Received(1).DeleteCardAsync(stripeCustomer.Id, cardId); await stripeAdapter.DidNotReceiveWithAnyArgs() - .PaymentMethodDetachAsync(Arg.Any(), Arg.Any()); + .DetachPaymentMethodAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -978,11 +978,11 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List { new () @@ -997,15 +997,15 @@ public class SubscriberServiceTests await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.DidNotReceiveWithAnyArgs().BankAccountDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteBankAccountAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.DidNotReceiveWithAnyArgs().CardDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteCardAsync(Arg.Any(), Arg.Any()); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(bankAccountId); + .DetachPaymentMethodAsync(bankAccountId); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(cardId); + .DetachPaymentMethodAsync(cardId); } private static async IAsyncEnumerable GetPaymentMethodsAsync( @@ -1050,7 +1050,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1062,7 +1062,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1076,10 +1076,10 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([new SetupIntent(), new SetupIntent()]); await ThrowsBillingExceptionAsync(() => @@ -1093,7 +1093,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1107,10 +1107,10 @@ public class SubscriberServiceTests var matchingSetupIntent = new SetupIntent { Id = "setup_intent_1" }; - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([matchingSetupIntent]); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); @@ -1119,12 +1119,12 @@ public class SubscriberServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_1"); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); } @@ -1135,7 +1135,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")) ) @@ -1148,22 +1148,22 @@ public class SubscriberServiceTests } }); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.Card, "TOKEN")); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).PaymentMethodAttachAsync("TOKEN", Arg.Is( + await stripeAdapter.Received(1).AttachPaymentMethodAsync("TOKEN", Arg.Is( options => options.Customer == provider.GatewayCustomerId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.InvoiceSettings.DefaultPaymentMethod == "TOKEN" && options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); @@ -1176,7 +1176,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1202,7 +1202,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1240,7 +1240,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1294,7 +1294,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1363,7 +1363,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId @@ -1395,7 +1395,7 @@ public class SubscriberServiceTests new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN"))); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -1405,7 +1405,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1442,7 +1442,7 @@ public class SubscriberServiceTests await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN")); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == braintreeCustomerId)); } @@ -1473,7 +1473,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1487,7 +1487,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "US" && @@ -1522,12 +1522,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1536,13 +1536,13 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -1555,7 +1555,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1569,7 +1569,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "CA" && @@ -1605,12 +1605,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1619,16 +1619,16 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -1655,7 +1655,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1669,7 +1669,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1678,12 +1678,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer()); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Returns(new Customer()); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } [Theory, BitAutoData] @@ -1693,12 +1693,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } #endregion @@ -1724,7 +1724,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1738,7 +1738,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1747,12 +1747,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } [Theory, BitAutoData] @@ -1762,12 +1762,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } #endregion diff --git a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs index 570f94575f..41f8839eb4 100644 --- a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs @@ -6,7 +6,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Entities; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -98,13 +97,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(organization).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(organization); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is((SubscriptionCreateOptions options) => + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is((SubscriptionCreateOptions options) => options.AutomaticTax.Enabled == true && options.CollectionMethod == CollectionMethod.ChargeAutomatically && options.Customer == "cus_123" && @@ -154,13 +153,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(provider).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(provider); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _providerRepository.Received(1).ReplaceAsync(Arg.Is(prov => prov.Id == providerId && @@ -199,13 +198,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(user).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(user); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => u.Id == userId && diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs index 786a6f6c0d..a6db6ae8fd 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,7 +13,7 @@ public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseT protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, OrganizationSponsorship sponsorship, SutProvider sutProvider) { - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); if (sponsorship != null) @@ -46,7 +47,7 @@ OrganizationSponsorship sponsorship, SutProvider sutProvider) protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .RemoveOrganizationSponsorshipAsync(default, default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .UpsertAsync(default); diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs index 69e7183c65..127cc7e502 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs @@ -1,10 +1,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -82,7 +82,7 @@ public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() .SponsorOrganizationAsync(default, default); await sutProvider.GetDependency() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs index fb64c11312..83e1487b01 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; @@ -54,7 +55,7 @@ public class AddSecretsManagerSubscriptionCommandTests c.AdditionalServiceAccounts == additionalServiceAccounts && c.AdditionalSeats == organization.Seats.GetValueOrDefault())); - await sutProvider.GetDependency().Received() + await sutProvider.GetDependency().Received() .AddSecretsManagerToSubscription(organization, plan, additionalSmSeats, additionalServiceAccounts); // TODO: call ReferenceEventService - see AC-1481 @@ -150,7 +151,7 @@ public class AddSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AddSecretsManagerToSubscription(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs index baa9e04c22..510433a2fa 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -86,9 +87,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -136,9 +137,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -258,7 +259,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( + await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( Arg.Is(o => o.Id == organizationId), plan, expectedSmServiceAccountsExcludingBase); @@ -779,9 +780,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustSmSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustServiceAccountsAsync(Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 await sutProvider.GetDependency().DidNotReceive() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs index 3841f7a619..8a00604bb0 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -121,7 +122,7 @@ public class UpgradeOrganizationPlanCommandTests Users = 1 }); await sutProvider.Sut.UpgradePlanAsync(organization.Id, organizationUpgrade); - await sutProvider.GetDependency().Received(1).AdjustSubscription( + await sutProvider.GetDependency().Received(1).AdjustSubscription( organization, MockPlans.Get(planType), organizationUpgrade.AdditionalSeats, diff --git a/test/Core.Test/Utilities/EventIntegrationsCacheConstantsTests.cs b/test/Core.Test/Utilities/EventIntegrationsCacheConstantsTests.cs index f6084c9209..a87392c2c1 100644 --- a/test/Core.Test/Utilities/EventIntegrationsCacheConstantsTests.cs +++ b/test/Core.Test/Utilities/EventIntegrationsCacheConstantsTests.cs @@ -55,20 +55,6 @@ public class EventIntegrationsCacheConstantsTests Assert.NotEqual(keyWithEvent, keyWithDifferentIntegration); Assert.NotEqual(keyWithEvent, keyWithDifferentOrganization); Assert.Equal(keyWithEvent, keyWithSameDetails); - - var expectedWithNullEvent = $"OrganizationIntegrationConfigurationDetails:{orgId:N}:Hec:"; - var keyWithNullEvent = EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( - orgId, integrationType, null); - var keyWithNullEventDifferentIntegration = EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( - orgId, IntegrationType.Webhook, null); - var keyWithNullEventDifferentOrganization = EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( - Guid.NewGuid(), integrationType, null); - - Assert.Equal(expectedWithNullEvent, keyWithNullEvent); - Assert.NotEqual(keyWithEvent, keyWithNullEvent); - Assert.NotEqual(keyWithNullEvent, keyWithDifferentEvent); - Assert.NotEqual(keyWithNullEvent, keyWithNullEventDifferentIntegration); - Assert.NotEqual(keyWithNullEvent, keyWithNullEventDifferentOrganization); } [Theory, BitAutoData] diff --git a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs index 8c53ed42a4..370ffd34aa 100644 --- a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs @@ -112,13 +112,6 @@ public class BaseRequestValidatorTests .Returns(true); } - private void SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(bool recoveryCodeSupportEnabled) - { - _featureService - .IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers) - .Returns(recoveryCodeSupportEnabled); - } - /* Logic path * ValidateAsync -> UpdateFailedAuthDetailsAsync -> _mailService.SendFailedLoginAttemptsEmailAsync * |-> BuildErrorResultAsync -> _eventService.LogUserEventAsync @@ -126,16 +119,14 @@ public class BaseRequestValidatorTests * |-> SetErrorResult */ [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_ContextNotValid_SelfHosted_ShouldBuildErrorResult_ShouldLogWarning( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _globalSettings.SelfHosted = true; _sut.isValid = false; @@ -152,16 +143,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_DeviceNotValidated_ShouldLogError( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass @@ -194,16 +183,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_DeviceValidated_ShouldSucceed( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass @@ -240,16 +227,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_ValidatedAuthRequest_ConsumedOnSuccess( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass @@ -305,16 +290,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_ValidatedAuthRequest_NotConsumed_When2faRequired( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); // 1 -> to pass @@ -337,7 +320,8 @@ public class BaseRequestValidatorTests // 2 -> will result to false with no extra configuration // 3 -> set two factor to be required - requestContext.User.TwoFactorProviders = "{\"1\":{\"Enabled\":true,\"MetaData\":{\"Email\":\"user@test.dev\"}}}"; + requestContext.User.TwoFactorProviders = + "{\"1\":{\"Enabled\":true,\"MetaData\":{\"Email\":\"user@test.dev\"}}}"; _twoFactorAuthenticationValidator .RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(true, null))); @@ -347,7 +331,7 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Dictionary { { "TwoFactorProviders", new[] { "0", "1" } }, - { "TwoFactorProviders2", new Dictionary{{"Email", null}} } + { "TwoFactorProviders2", new Dictionary { { "Email", null } } } })); // Act @@ -364,16 +348,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_TwoFactorTokenInvalid_ShouldSendFailedTwoFactorEmail( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -408,16 +390,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_TwoFactorRememberTokenExpired_ShouldNotSendFailedTwoFactorEmail( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -463,21 +443,17 @@ public class BaseRequestValidatorTests // Test grantTypes that require SSO when a user is in an organization that requires it [Theory] - [BitAutoData("password", true)] - [BitAutoData("password", false)] - [BitAutoData("webauthn", true)] - [BitAutoData("webauthn", false)] - [BitAutoData("refresh_token", true)] - [BitAutoData("refresh_token", false)] + [BitAutoData("password")] + [BitAutoData("webauthn")] + [BitAutoData("refresh_token")] public async Task ValidateAsync_GrantTypes_OrgSsoRequiredTrue_ShouldSetSsoResult( string grantType, - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -497,21 +473,17 @@ public class BaseRequestValidatorTests // Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled [Theory] - [BitAutoData("password", true)] - [BitAutoData("password", false)] - [BitAutoData("webauthn", true)] - [BitAutoData("webauthn", false)] - [BitAutoData("refresh_token", true)] - [BitAutoData("refresh_token", false)] + [BitAutoData("password")] + [BitAutoData("webauthn")] + [BitAutoData("refresh_token")] public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredTrue_ShouldSetSsoResult( string grantType, - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -533,21 +505,17 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData("password", true)] - [BitAutoData("password", false)] - [BitAutoData("webauthn", true)] - [BitAutoData("webauthn", false)] - [BitAutoData("refresh_token", true)] - [BitAutoData("refresh_token", false)] + [BitAutoData("password")] + [BitAutoData("webauthn")] + [BitAutoData("refresh_token")] public async Task ValidateAsync_GrantTypes_WithPolicyRequirementsEnabled_OrgSsoRequiredFalse_ShouldSucceed( string grantType, - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -582,21 +550,17 @@ public class BaseRequestValidatorTests // Test grantTypes where SSO would be required but the user is not in an // organization that requires it [Theory] - [BitAutoData("password", true)] - [BitAutoData("password", false)] - [BitAutoData("webauthn", true)] - [BitAutoData("webauthn", false)] - [BitAutoData("refresh_token", true)] - [BitAutoData("refresh_token", false)] + [BitAutoData("password")] + [BitAutoData("webauthn")] + [BitAutoData("refresh_token")] public async Task ValidateAsync_GrantTypes_OrgSsoRequiredFalse_ShouldSucceed( string grantType, - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -631,19 +595,17 @@ public class BaseRequestValidatorTests // Test the grantTypes where SSO is in progress or not relevant [Theory] - [BitAutoData("authorization_code", true)] - [BitAutoData("authorization_code", false)] - [BitAutoData("client_credentials", true)] - [BitAutoData("client_credentials", false)] + [BitAutoData("authorization_code")] + [BitAutoData("client_credentials")] + [BitAutoData("client_credentials")] public async Task ValidateAsync_GrantTypes_SsoRequiredFalse_ShouldSucceed( string grantType, - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -679,16 +641,14 @@ public class BaseRequestValidatorTests * ValidateAsync -> UserService.IsLegacyUser -> FailAuthForLegacyUserAsync */ [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_IsLegacyUser_FailAuthForLegacyUserAsync( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = context.CustomValidatorRequestContext.User; user.Key = null; @@ -713,16 +673,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_CustomResponse_NoMasterPassword_ShouldSetUserDecryptionOptions( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); @@ -763,19 +721,16 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true, KdfType.PBKDF2_SHA256, 654_321, null, null)] - [BitAutoData(false, KdfType.PBKDF2_SHA256, 654_321, null, null)] - [BitAutoData(true, KdfType.Argon2id, 11, 128, 5)] - [BitAutoData(false, KdfType.Argon2id, 11, 128, 5)] + [BitAutoData(KdfType.PBKDF2_SHA256, 654_321, null, null)] + [BitAutoData(KdfType.Argon2id, 11, 128, 5)] public async Task ValidateAsync_CustomResponse_MasterPassword_ShouldSetUserDecryptionOptions( - bool featureFlagValue, KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); _userDecryptionOptionsBuilder.ForUser(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithDevice(Arg.Any()).Returns(_userDecryptionOptionsBuilder); _userDecryptionOptionsBuilder.WithSso(Arg.Any()).Returns(_userDecryptionOptionsBuilder); @@ -834,16 +789,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_CustomResponse_ShouldIncludeAccountKeys( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var mockAccountKeys = new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -916,16 +869,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_CustomResponse_AccountKeysQuery_SkippedWhenPrivateKeyIsNull( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); requestContext.User.PrivateKey = null; var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -946,16 +897,14 @@ public class BaseRequestValidatorTests } [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_CustomResponse_AccountKeysQuery_CalledWithCorrectUser( - bool featureFlagValue, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagValue); var expectedUser = requestContext.User; _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData @@ -995,22 +944,20 @@ public class BaseRequestValidatorTests /// Tests the core PM-21153 feature: SSO-required users can use recovery codes to disable 2FA, /// but must then authenticate via SSO with a descriptive message about the recovery. /// This test validates: - /// 1. Validation order is changed (2FA before SSO) when recovery code is provided + /// 1. Validation order prioritizes 2FA before SSO when recovery code is provided /// 2. Recovery code successfully validates and sets TwoFactorRecoveryRequested flag /// 3. SSO validation then fails with recovery-specific message /// 4. User is NOT logged in (must authenticate via IdP) /// [Theory] - [BitAutoData(true)] // Feature flag ON - new behavior - [BitAutoData(false)] // Feature flag OFF - should fail at SSO before 2FA recovery + [BitAutoData] public async Task ValidateAsync_RecoveryCodeForSsoRequiredUser_BlocksWithDescriptiveMessage( - bool featureFlagEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -1023,8 +970,8 @@ public class BaseRequestValidatorTests // 2. SSO is required (this user is in an org that requires SSO) _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); // 3. 2FA is required _twoFactorAuthenticationValidator @@ -1048,30 +995,16 @@ public class BaseRequestValidatorTests var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - if (featureFlagEnabled) - { - // NEW BEHAVIOR: Recovery succeeds, then SSO blocks with descriptive message - Assert.Equal( - "Two-factor recovery has been performed. SSO authentication is required.", - errorResponse.Message); + // Recovery succeeds, then SSO blocks with descriptive message + Assert.Equal( + "Two-factor recovery has been performed. SSO authentication is required.", + errorResponse.Message); - // Verify recovery was marked - Assert.True(requestContext.TwoFactorRecoveryRequested, - "TwoFactorRecoveryRequested flag should be set"); - } - else - { - // LEGACY BEHAVIOR: SSO blocks BEFORE recovery can happen - Assert.Equal( - "SSO authentication is required.", - errorResponse.Message); + // Verify recovery was marked + Assert.True(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested flag should be set"); - // Recovery never happened because SSO checked first - Assert.False(requestContext.TwoFactorRecoveryRequested, - "TwoFactorRecoveryRequested should be false (SSO blocked first)"); - } - - // In both cases: User is NOT logged in + // User is NOT logged in await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn); } @@ -1086,16 +1019,14 @@ public class BaseRequestValidatorTests /// 4. NOT be logged in /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_InvalidRecoveryCodeForSsoRequiredUser_FailsAt2FA( - bool featureFlagEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -1104,8 +1035,8 @@ public class BaseRequestValidatorTests // 2. SSO is required _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(true)); // 3. 2FA is required _twoFactorAuthenticationValidator @@ -1129,51 +1060,32 @@ public class BaseRequestValidatorTests var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - if (featureFlagEnabled) - { - // NEW BEHAVIOR: 2FA is checked first (due to recovery code request), fails with 2FA error - Assert.Equal( - "Two-step token is invalid. Try again.", - errorResponse.Message); + // 2FA is checked first (due to recovery code request), fails with 2FA error + Assert.Equal( + "Two-step token is invalid. Try again.", + errorResponse.Message); - // Recovery was attempted but failed - flag should NOT be set - Assert.False(requestContext.TwoFactorRecoveryRequested, - "TwoFactorRecoveryRequested should be false (recovery failed)"); + // Recovery was attempted but failed - flag should NOT be set + Assert.False(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested should be false (recovery failed)"); - // Verify failed 2FA email was sent - await _mailService.Received(1).SendFailedTwoFactorAttemptEmailAsync( - user.Email, - TwoFactorProviderType.RecoveryCode, - Arg.Any(), - Arg.Any()); + // Verify failed 2FA email was sent + await _mailService.Received(1).SendFailedTwoFactorAttemptEmailAsync( + user.Email, + TwoFactorProviderType.RecoveryCode, + Arg.Any(), + Arg.Any()); - // Verify failed login event was logged - await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_FailedLogIn2fa); - } - else - { - // LEGACY BEHAVIOR: SSO is checked first, blocks before 2FA - Assert.Equal( - "SSO authentication is required.", - errorResponse.Message); + // Verify failed login event was logged + await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_FailedLogIn2fa); - // 2FA validation never happened - await _mailService.DidNotReceive().SendFailedTwoFactorAttemptEmailAsync( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any()); - } - // In both cases: User is NOT logged in + // User is NOT logged in await _eventService.DidNotReceive().LogUserEventAsync(user.Id, EventType.User_LoggedIn); // Verify user failed login count was updated (in new behavior path) - if (featureFlagEnabled) - { - await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => - u.Id == user.Id && u.FailedLoginCount > 0)); - } + await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => + u.Id == user.Id && u.FailedLoginCount > 0)); } /// @@ -1187,16 +1099,14 @@ public class BaseRequestValidatorTests /// This is the "happy path" for recovery code usage. /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_RecoveryCodeForNonSsoUser_SuccessfulLogin( - bool featureFlagEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(featureFlagEnabled); var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -1205,8 +1115,8 @@ public class BaseRequestValidatorTests // 2. SSO is NOT required (this is a regular user, not in SSO org) _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(false)); + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + .Returns(Task.FromResult(false)); // 3. 2FA is required _twoFactorAuthenticationValidator @@ -1243,7 +1153,8 @@ public class BaseRequestValidatorTests await _sut.ValidateAsync(context); // Assert - Assert.False(context.GrantResult.IsError, "Authentication should succeed for non-SSO user with valid recovery code"); + Assert.False(context.GrantResult.IsError, + "Authentication should succeed for non-SSO user with valid recovery code"); // Verify user successfully logged in await _eventService.Received(1).LogUserEventAsync(user.Id, EventType.User_LoggedIn); @@ -1252,19 +1163,9 @@ public class BaseRequestValidatorTests await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => u.Id == user.Id && u.FailedLoginCount == 0)); - if (featureFlagEnabled) - { - // NEW BEHAVIOR: Recovery flag should be set for audit purposes - Assert.True(requestContext.TwoFactorRecoveryRequested, - "TwoFactorRecoveryRequested flag should be set for audit/logging"); - } - else - { - // LEGACY BEHAVIOR: Recovery flag doesn't exist, but login still succeeds - // (SSO check happens before 2FA in legacy, but user is not SSO-required so both pass) - Assert.False(requestContext.TwoFactorRecoveryRequested, - "TwoFactorRecoveryRequested should be false in legacy mode"); - } + // Recovery flag should be set for audit purposes + Assert.True(requestContext.TwoFactorRecoveryRequested, + "TwoFactorRecoveryRequested flag should be set for audit/logging"); } [Theory] @@ -1308,16 +1209,14 @@ public class BaseRequestValidatorTests /// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach. /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation( - bool recoveryCodeFeatureEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1327,7 +1226,7 @@ public class BaseRequestValidatorTests // SSO is required via legacy path _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(true)); // Act @@ -1352,16 +1251,14 @@ public class BaseRequestValidatorTests /// instead of the legacy RequireSsoLoginAsync method. /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator( - bool recoveryCodeFeatureEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1371,9 +1268,9 @@ public class BaseRequestValidatorTests // Configure SsoRequestValidator to indicate SSO is required _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) + Arg.Any(), + Arg.Any(), + Arg.Any()) .Returns(Task.FromResult(false)); // false = SSO required // Set up the ValidationErrorResult that SsoRequestValidator would set @@ -1410,16 +1307,14 @@ public class BaseRequestValidatorTests /// authentication continues successfully through the new validation path. /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin( - bool recoveryCodeFeatureEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1430,9 +1325,9 @@ public class BaseRequestValidatorTests // SsoRequestValidator returns true (SSO not required) _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) + Arg.Any(), + Arg.Any(), + Arg.Any()) .Returns(Task.FromResult(true)); // No 2FA required @@ -1473,16 +1368,14 @@ public class BaseRequestValidatorTests /// (e.g., with organization identifier), that custom response is properly propagated to the result. /// [Theory] - [BitAutoData(true)] - [BitAutoData(false)] + [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse( - bool recoveryCodeFeatureEnabled, [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(recoveryCodeFeatureEnabled); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); _sut.isValid = true; @@ -1504,9 +1397,9 @@ public class BaseRequestValidatorTests var context = CreateContext(tokenRequest, requestContext, grantResult); _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) + Arg.Any(), + Arg.Any(), + Arg.Any()) .Returns(Task.FromResult(false)); // Act @@ -1516,7 +1409,8 @@ public class BaseRequestValidatorTests Assert.True(context.GrantResult.IsError); Assert.NotNull(context.GrantResult.CustomResponse); Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse); - Assert.Equal("test-org-identifier", context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]); + Assert.Equal("test-org-identifier", + context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]); } /// @@ -1527,11 +1421,11 @@ public class BaseRequestValidatorTests [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1552,7 +1446,7 @@ public class BaseRequestValidatorTests // SSO is required (legacy check) _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) + Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(true)); // Act @@ -1578,11 +1472,11 @@ public class BaseRequestValidatorTests [BitAutoData] public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, + [AuthFixtures.CustomValidatorRequestContext] + CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - SetupRecoveryCodeSupportForSsoRequiredUsersFeatureFlag(true); _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1611,13 +1505,16 @@ public class BaseRequestValidatorTests }; requestContext.CustomResponse = new Dictionary { - { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } + { + "ErrorModel", + new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") + } }; _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) + Arg.Any(), + Arg.Any(), + Arg.Any()) .Returns(Task.FromResult(false)); // Act diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs index a41cd43923..a9b3e6f7f0 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs @@ -227,7 +227,7 @@ public abstract class WebApplicationFactoryBase : WebApplicationFactory services.AddSingleton(); // Noop StripePaymentService - this could be changed to integrate with our Stripe test account - Replace(services, Substitute.For()); + Replace(services, Substitute.For()); Replace(services, Substitute.For()); });