diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 30fcf29206..fe82f9fbe6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -484,7 +484,7 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Trigger self-host build - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} script: | @@ -525,7 +525,7 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Trigger k8s deploy - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} script: | diff --git a/Directory.Build.props b/Directory.Build.props index 71303d3529..76f35e297e 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.9.2 + 2025.10.0 Bit.$(MSBuildProjectName) enable diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index aa19ad5382..aaf0050b63 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -12,7 +12,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; @@ -90,7 +90,7 @@ public class ProviderService : IProviderService _providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand; } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) { var owner = await _userService.GetUserByIdAsync(ownerUserId); if (owner == null) @@ -115,21 +115,7 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid owner."); } - if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) - { - throw new BadRequestException("Both address and postal code are required to set up your provider."); - } - - if (tokenizedPaymentSource is not - { - Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, - Token: not null and not "" - }) - { - throw new BadRequestException("A payment method is required to set up your provider."); - } - - var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); provider.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index 398674c7b6..c9851eb403 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; @@ -21,10 +22,8 @@ using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -38,6 +37,9 @@ using Subscription = Stripe.Subscription; namespace Bit.Commercial.Core.Billing.Providers.Services; +using static Constants; +using static StripeConstants; + public class ProviderBillingService( IBraintreeGateway braintreeGateway, IEventService eventService, @@ -51,8 +53,7 @@ public class ProviderBillingService( IProviderUserRepository providerUserRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService, - ITaxService taxService) + ISubscriberService subscriberService) : IProviderBillingService { public async Task AddExistingOrganization( @@ -61,10 +62,7 @@ public class ProviderBillingService( string key) { await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, - new SubscriptionUpdateOptions - { - CancelAtPeriodEnd = false - }); + new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); var subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, @@ -83,7 +81,7 @@ public class ProviderBillingService( var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now; - if (!wasTrialing && subscription.LatestInvoice.Status == StripeConstants.InvoiceStatus.Draft) + if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft) { await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = true }); @@ -184,16 +182,8 @@ public class ProviderBillingService( { Items = [ - new SubscriptionItemOptions - { - Price = newPriceId, - Quantity = oldSubscriptionItem!.Quantity - }, - new SubscriptionItemOptions - { - Id = oldSubscriptionItem.Id, - Deleted = true - } + new SubscriptionItemOptions { Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity }, + new SubscriptionItemOptions { Id = oldSubscriptionItem.Id, Deleted = true } ] }; @@ -202,7 +192,8 @@ public class ProviderBillingService( // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan // 2. Assign PlanType & PlanName to Organization - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); + var providerOrganizations = + await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); var newPlan = await pricingClient.GetPlanOrThrow(newPlanType); @@ -213,6 +204,7 @@ public class ProviderBillingService( { throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); } + organization.PlanType = newPlanType; organization.Plan = newPlan.Name; await organizationRepository.ReplaceAsync(organization); @@ -228,15 +220,15 @@ public class ProviderBillingService( if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) { - logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); + logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, + nameof(organization.GatewayCustomerId)); return; } - var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions - { - Expand = ["tax", "tax_ids"] - }); + var providerCustomer = + await subscriberService.GetCustomerOrThrow(provider, + new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); @@ -269,23 +261,18 @@ public class ProviderBillingService( } ] }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - }, - TaxIdData = providerTaxId == null ? null : - [ - new CustomerTaxIdDataOptions - { - Type = providerTaxId.Type, - Value = providerTaxId.Value - } - ] + Metadata = new Dictionary { { "region", globalSettings.BaseServiceUri.CloudRegion } }, + TaxIdData = providerTaxId == null + ? null + : + [ + new CustomerTaxIdDataOptions { Type = providerTaxId.Type, Value = providerTaxId.Value } + ] }; - if (providerCustomer.Address is not { Country: Constants.CountryAbbreviations.UnitedStates }) + if (providerCustomer.Address is not { Country: CountryAbbreviations.UnitedStates }) { - customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + customerCreateOptions.TaxExempt = TaxExempt.Reverse; } var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); @@ -347,9 +334,9 @@ public class ProviderBillingService( .Where(pair => pair.subscription is { Status: - StripeConstants.SubscriptionStatus.Active or - StripeConstants.SubscriptionStatus.Trialing or - StripeConstants.SubscriptionStatus.PastDue + SubscriptionStatus.Active or + SubscriptionStatus.Trialing or + SubscriptionStatus.PastDue }).ToList(); if (active.Count == 0) @@ -474,37 +461,27 @@ public class ProviderBillingService( // Below the limit to above the limit (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) || // Above the limit to further above the limit - (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > currentlyAssignedSeatTotal); + (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal > currentlyAssignedSeatTotal); } public async Task SetupCustomer( Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) { - ArgumentNullException.ThrowIfNull(tokenizedPaymentSource); - - if (taxInfo is not - { - BillingAddressCountry: not null and not "", - BillingAddressPostalCode: not null and not "" - }) - { - logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id); - throw new BillingException(); - } - var options = new CustomerCreateOptions { Address = new AddressOptions { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode, + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + State = billingAddress.State }, + Coupon = !string.IsNullOrEmpty(provider.DiscountId) ? provider.DiscountId : null, Description = provider.DisplayBusinessName(), Email = provider.BillingEmail, InvoiceSettings = new CustomerInvoiceSettingsOptions @@ -520,93 +497,61 @@ public class ProviderBillingService( } ] }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - } + Metadata = new Dictionary { { "region", globalSettings.BaseServiceUri.CloudRegion } }, + TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None }; - if (taxInfo.BillingAddressCountry is not Constants.CountryAbbreviations.UnitedStates) + if (billingAddress.TaxId != null) { - options.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) - { - var taxIdType = taxService.GetStripeTaxCode( - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - - if (taxIdType == null) - { - logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - options.TaxIdData = [ - new CustomerTaxIdDataOptions { Type = taxIdType, Value = taxInfo.TaxIdNumber } + new CustomerTaxIdDataOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value } ]; - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) + if (billingAddress.TaxId.Code == TaxIdType.SpanishNIF) { options.TaxIdData.Add(new CustomerTaxIdDataOptions { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{taxInfo.TaxIdNumber}" + Type = TaxIdType.EUVAT, + Value = $"ES{billingAddress.TaxId.Value}" }); } } - if (!string.IsNullOrEmpty(provider.DiscountId)) - { - options.Coupon = provider.DiscountId; - } - var braintreeCustomerId = ""; - if (tokenizedPaymentSource is not - { - Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, - Token: not null and not "" - }) - { - logger.LogError("Cannot create customer for provider ({ProviderID}) with invalid payment method", provider.Id); - throw new BillingException(); - } - - var (type, token) = tokenizedPaymentSource; - // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (type) + switch (paymentMethod.Type) { - case PaymentMethodType.BankAccount: + case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token })) + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions + { + PaymentMethod = paymentMethod.Token + })) .FirstOrDefault(); if (setupIntent == null) { - logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id); + logger.LogError( + "Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", + provider.Id); throw new BillingException(); } await setupIntentCache.Set(provider.Id, setupIntent.Id); break; } - case PaymentMethodType.Card: + case TokenizablePaymentMethodType.Card: { - options.PaymentMethod = token; - options.InvoiceSettings.DefaultPaymentMethod = token; + options.PaymentMethod = paymentMethod.Token; + options.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token; break; } - case PaymentMethodType.PayPal: + case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, paymentMethod.Token); options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } @@ -616,8 +561,7 @@ public class ProviderBillingService( { return await stripeAdapter.CustomerCreateAsync(options); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == - StripeConstants.ErrorCodes.TaxIdInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid) { await Revert(); throw new BadRequestException( @@ -632,9 +576,9 @@ public class ProviderBillingService( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (tokenizedPaymentSource.Type) + switch (paymentMethod.Type) { - case PaymentMethodType.BankAccount: + case TokenizablePaymentMethodType.BankAccount: { var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); await stripeAdapter.SetupIntentCancel(setupIntentId, @@ -642,7 +586,7 @@ public class ProviderBillingService( await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); break; } - case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): { await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); break; @@ -661,9 +605,10 @@ public class ProviderBillingService( var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - if (providerPlans == null || providerPlans.Count == 0) + if (providerPlans.Count == 0) { - logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id); + logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", + provider.Id); throw new BillingException(); } @@ -676,7 +621,9 @@ public class ProviderBillingService( if (!providerPlan.IsConfigured()) { - logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", provider.Id, plan.Name); + logger.LogError( + "Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", + provider.Id, plan.Name); throw new BillingException(); } @@ -692,16 +639,14 @@ public class ProviderBillingService( var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntent = !string.IsNullOrEmpty(setupIntentId) - ? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions - { - Expand = ["payment_method"] - }) + ? await stripeAdapter.SetupIntentGet(setupIntentId, + new SetupIntentGetOptions { Expand = ["payment_method"] }) : null; var usePaymentMethod = !string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) || - (customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true) || - (setupIntent?.IsUnverifiedBankAccount() == true); + customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true || + setupIntent?.IsUnverifiedBankAccount() == true; int? trialPeriodDays = provider.Type switch { @@ -712,30 +657,28 @@ public class ProviderBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - CollectionMethod = usePaymentMethod ? - StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice, + CollectionMethod = + usePaymentMethod + ? CollectionMethod.ChargeAutomatically + : CollectionMethod.SendInvoice, Customer = customer.Id, DaysUntilDue = usePaymentMethod ? null : 30, Items = subscriptionItemOptionsList, - Metadata = new Dictionary - { - { "providerId", provider.Id.ToString() } - }, + Metadata = new Dictionary { { "providerId", provider.Id.ToString() } }, OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, - TrialPeriodDays = trialPeriodDays + ProrationBehavior = ProrationBehavior.CreateProrations, + TrialPeriodDays = trialPeriodDays, + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } }; - subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - try { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (subscription is { - Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing + Status: SubscriptionStatus.Active or SubscriptionStatus.Trialing }) { return subscription; @@ -749,9 +692,11 @@ public class ProviderBillingService( throw new BillingException(); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == + ErrorCodes.CustomerTaxLocationInvalid) { - throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid."); + throw new BadRequestException( + "Your location wasn't recognized. Please ensure your country and postal code are valid."); } } @@ -765,7 +710,7 @@ public class ProviderBillingService( subscriberService.UpdateTaxInformation(provider, taxInformation)); await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, - new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically }); + new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically }); } public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) @@ -865,13 +810,9 @@ public class ProviderBillingService( await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { - Items = [ - new SubscriptionItemOptions - { - Id = item.Id, - Price = priceId, - Quantity = newlySubscribedSeats - } + Items = + [ + new SubscriptionItemOptions { Id = item.Id, Price = priceId, Quantity = newlySubscribedSeats } ] }); @@ -894,7 +835,8 @@ public class ProviderBillingService( var plan = await pricingClient.GetPlanOrThrow(planType); return providerOrganizations - .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) + .Where(providerOrganization => providerOrganization.Plan == plan.Name && + providerOrganization.Status == OrganizationStatusType.Managed) .Sum(providerOrganization => providerOrganization.Seats ?? 0); } diff --git a/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs b/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs index 12c7f679bd..b73b358925 100644 --- a/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs +++ b/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs @@ -1,10 +1,13 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Commands.ServiceAccounts.Interfaces; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; namespace Bit.Commercial.Core.SecretsManager.Commands.ServiceAccounts; @@ -13,15 +16,21 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand private readonly IAccessPolicyRepository _accessPolicyRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly IEventService _eventService; + private readonly ICurrentContext _currentContext; public CreateServiceAccountCommand( IAccessPolicyRepository accessPolicyRepository, IOrganizationUserRepository organizationUserRepository, - IServiceAccountRepository serviceAccountRepository) + IServiceAccountRepository serviceAccountRepository, + IEventService eventService, + ICurrentContext currentContext) { _accessPolicyRepository = accessPolicyRepository; _organizationUserRepository = organizationUserRepository; _serviceAccountRepository = serviceAccountRepository; + _eventService = eventService; + _currentContext = currentContext; } public async Task CreateAsync(ServiceAccount serviceAccount, Guid userId) @@ -38,6 +47,7 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand Write = true, }; await _accessPolicyRepository.CreateManyAsync(new List { accessPolicy }); + await _eventService.LogServiceAccountPeopleEventAsync(user.Id, accessPolicy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType); return createdServiceAccount; } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index f2ba2fab8f..e61cf5f97e 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -9,7 +9,7 @@ using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; @@ -41,7 +41,7 @@ public class ProviderServiceTests public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null)); + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null, null)); Assert.Contains("Invalid owner.", exception.Message); } @@ -53,83 +53,12 @@ public class ProviderServiceTests userService.GetUserByIdAsync(user.Id).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null)); + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null, null)); Assert.Contains("Invalid token.", exception.Message); } [Theory, BitAutoData] - public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException( - User user, - Provider provider, - string key, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource, - [ProviderUser] ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - taxInfo.BillingAddressCountry = null; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); - - Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message); - } - - [Theory, BitAutoData] - public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException( - User user, - Provider provider, - string key, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource, - [ProviderUser] ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - - tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); - - Assert.Equal("A payment method is required to set up your provider.", exception.Message); - } - - [Theory, BitAutoData] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource, + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress, [ProviderUser] ProviderUser providerUser, SutProvider sutProvider) { @@ -149,7 +78,7 @@ public class ProviderServiceTests var providerBillingService = sutProvider.GetDependency(); var customer = new Customer { Id = "customer_id" }; - providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer); + providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer); var subscription = new Subscription { Id = "subscription_id" }; providerBillingService.SetupSubscription(provider).Returns(subscription); @@ -158,7 +87,7 @@ public class ProviderServiceTests var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource); + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, billingAddress); await sutProvider.GetDependency().Received().UpsertAsync(Arg.Is( p => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 54c0b82aa9..18c71364e6 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -1,5 +1,4 @@ using System.Globalization; -using System.Net; using Bit.Commercial.Core.Billing.Providers.Models; using Bit.Commercial.Core.Billing.Providers.Services; using Bit.Core.AdminConsole.Entities; @@ -10,18 +9,16 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -895,118 +892,53 @@ public class ProviderBillingServiceTests #region SetupCustomer [Theory, BitAutoData] - public async Task SetupCustomer_MissingCountry_ContactSupport( + public async Task SetupCustomer_NullPaymentMethod_ThrowsNullReferenceException( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + BillingAddress billingAddress) { - taxInfo.BillingAddressCountry = null; - - await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task SetupCustomer_MissingPostalCode_ContactSupport( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) - { - taxInfo.BillingAddressCountry = null; - - await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - - [Theory, BitAutoData] - public async Task SetupCustomer_NullPaymentSource_ThrowsArgumentNullException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, null)); - } - - [Theory, BitAutoData] - public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) - { - provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; - - - tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; - - await ThrowsBillingExceptionAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, null, billingAddress)); } [Theory, BitAutoData] public async Task SetupCustomer_WithBankAccount_Error_Reverts( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); - - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; stripeAdapter.SetupIntentList(Arg.Is(options => - options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Throws(); sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns("setup_intent_id"); await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); @@ -1020,45 +952,37 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithPayPal_Error_Reverts( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); - - - sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && o.Metadata["btCustomerId"] == "braintree_customer_id" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Throws(); await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); await sutProvider.GetDependency().Customer.Received(1).DeleteAsync("braintree_customer_id"); } @@ -1067,17 +991,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithBankAccount_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1087,31 +1005,30 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; stripeAdapter.SetupIntentList(Arg.Is(options => - options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); @@ -1122,17 +1039,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithPayPal_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1142,30 +1053,29 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" }; - - sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && o.Metadata["btCustomerId"] == "braintree_customer_id" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } @@ -1174,17 +1084,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithCard_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1194,28 +1098,26 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.PaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } @@ -1224,17 +1126,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithCard_ReverseCharge_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "FR"; // Non-US country to trigger reverse charge + billingAddress.TaxId = new TaxID("fr_siren", "123456789"); var stripeAdapter = sutProvider.GetDependency(); @@ -1244,55 +1140,51 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.PaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber && + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value && o.TaxExempt == StripeConstants.TaxExempt.Reverse)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } [Theory, BitAutoData] - public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( + public async Task SetupCustomer_WithInvalidTaxId_ThrowsBadRequestException( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + BillingAddress billingAddress) { provider.Name = "MSP"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "invalid_tax_id"); - taxInfo.BillingAddressCountry = "AD"; + var stripeAdapter = sutProvider.GetDependency(); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns((string)null); + stripeAdapter.CustomerCreateAsync(Arg.Any()) + .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } }); var actual = await Assert.ThrowsAsync(async () => - await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); - Assert.IsType(actual); - Assert.Equal("billingTaxIdTypeInferenceError", actual.Message); + Assert.Equal("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.", actual.Message); } #endregion diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 0ee4aa53a9..c5e42cf9e3 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -53,6 +53,7 @@ services: - ./.data/postgres/log:/var/log/postgresql profiles: - postgres + - ef mysql: image: mysql:8.0 @@ -69,6 +70,7 @@ services: - mysql_dev_data:/var/lib/mysql profiles: - mysql + - ef mariadb: image: mariadb:10 @@ -76,13 +78,13 @@ services: - 4306:3306 environment: MARIADB_USER: maria - MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD} MARIADB_DATABASE: vault_dev MARIADB_RANDOM_ROOT_PASSWORD: "true" volumes: - mariadb_dev_data:/var/lib/mysql profiles: - mariadb + - ef idp: image: kenchan0130/simplesamlphp:1.19.8 @@ -99,7 +101,7 @@ services: - idp rabbitmq: - image: rabbitmq:4.1.0-management + image: rabbitmq:4.1.3-management container_name: rabbitmq ports: - "5672:5672" @@ -153,5 +155,6 @@ volumes: mssql_dev_data: postgres_dev_data: mysql_dev_data: + mariadb_dev_data: rabbitmq_data: redis_data: diff --git a/dev/migrate.ps1 b/dev/migrate.ps1 index 287a2d18ee..26caa87efd 100755 --- a/dev/migrate.ps1 +++ b/dev/migrate.ps1 @@ -70,7 +70,7 @@ Foreach ($item in @( @($mysql, "MySQL", "MySqlMigrations", "mySql", 2), # MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context. # However they can still be run independently for integration tests. - @($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3) + @($mariadb, "MariaDB", "MySqlMigrations", "mySql", 4) )) { if (!$item[0] -and !$all) { continue diff --git a/src/Api/AdminConsole/Controllers/EventsController.cs b/src/Api/AdminConsole/Controllers/EventsController.cs index 18199ad8f2..f868f0b3b6 100644 --- a/src/Api/AdminConsole/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Controllers/EventsController.cs @@ -30,6 +30,8 @@ public class EventsController : Controller private readonly ICurrentContext _currentContext; private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; + private readonly IServiceAccountRepository _serviceAccountRepository; + public EventsController( IUserService userService, @@ -39,7 +41,8 @@ public class EventsController : Controller IEventRepository eventRepository, ICurrentContext currentContext, ISecretRepository secretRepository, - IProjectRepository projectRepository) + IProjectRepository projectRepository, + IServiceAccountRepository serviceAccountRepository) { _userService = userService; _cipherRepository = cipherRepository; @@ -49,6 +52,7 @@ public class EventsController : Controller _currentContext = currentContext; _secretRepository = secretRepository; _projectRepository = projectRepository; + _serviceAccountRepository = serviceAccountRepository; } [HttpGet("")] @@ -184,6 +188,57 @@ public class EventsController : Controller return new ListResponseModel(responses, result.ContinuationToken); } + [HttpGet("~/organization/{orgId}/service-account/{id}/events")] + public async Task> GetServiceAccounts( + Guid orgId, + Guid id, + [FromQuery] DateTime? start = null, + [FromQuery] DateTime? end = null, + [FromQuery] string continuationToken = null) + { + if (id == Guid.Empty || orgId == Guid.Empty) + { + throw new NotFoundException(); + } + + var serviceAccount = await GetServiceAccount(id, orgId); + var org = _currentContext.GetOrganization(orgId); + + if (org == null || !await _currentContext.AccessEventLogs(org.Id)) + { + throw new NotFoundException(); + } + + var (fromDate, toDate) = ApiHelpers.GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationServiceAccountAsync( + serviceAccount.OrganizationId, + serviceAccount.Id, + fromDate, + toDate, + new PageOptions { ContinuationToken = continuationToken }); + + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [ApiExplorerSettings(IgnoreApi = true)] + private async Task GetServiceAccount(Guid serviceAccountId, Guid orgId) + { + var serviceAccount = await _serviceAccountRepository.GetByIdAsync(serviceAccountId); + if (serviceAccount != null) + { + return serviceAccount; + } + + var fallbackServiceAccount = new ServiceAccount + { + Id = serviceAccountId, + OrganizationId = orgId + }; + + return fallbackServiceAccount; + } + [HttpGet("~/organizations/{orgId}/users/{id}/events")] public async Task> GetOrganizationUser(string orgId, string id, [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index a1815fd3bf..aa87bf9c74 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Context; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; @@ -93,22 +92,12 @@ public class ProvidersController : Controller var userId = _userService.GetProperUserId(User).Value; - var taxInfo = new TaxInfo - { - BillingAddressCountry = model.TaxInfo.Country, - BillingAddressPostalCode = model.TaxInfo.PostalCode, - TaxIdNumber = model.TaxInfo.TaxId, - BillingAddressLine1 = model.TaxInfo.Line1, - BillingAddressLine2 = model.TaxInfo.Line2, - BillingAddressCity = model.TaxInfo.City, - BillingAddressState = model.TaxInfo.State - }; - - var tokenizedPaymentSource = model.PaymentSource?.ToDomain(); + var paymentMethod = model.PaymentMethod.ToDomain(); + var billingAddress = model.BillingAddress.ToDomain(); var response = await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, - taxInfo, tokenizedPaymentSource); + paymentMethod, billingAddress); return new ProviderResponseModel(response); } diff --git a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs b/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs index 6e3751c6f6..c8ff4f9f7c 100644 --- a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs +++ b/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; +using System.Text.Json; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Core; using Bit.Core.AdminConsole.Entities; @@ -18,25 +15,58 @@ using Microsoft.AspNetCore.Mvc; namespace Bit.Api.AdminConsole.Controllers; [RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] -[Route("organizations/{organizationId:guid}/integrations/slack")] +[Route("organizations")] [Authorize("Application")] public class SlackIntegrationController( ICurrentContext currentContext, IOrganizationIntegrationRepository integrationRepository, - ISlackService slackService) : Controller + ISlackService slackService, + TimeProvider timeProvider) : Controller { - [HttpGet("redirect")] + [HttpGet("{organizationId:guid}/integrations/slack/redirect")] public async Task RedirectAsync(Guid organizationId) { if (!await currentContext.OrganizationOwner(organizationId)) { throw new NotFoundException(); } - string callbackUrl = Url.RouteUrl( - nameof(CreateAsync), - new { organizationId }, - currentContext.HttpContext.Request.Scheme); - var redirectUrl = slackService.GetRedirectUrl(callbackUrl); + + string? callbackUrl = Url.RouteUrl( + routeName: nameof(CreateAsync), + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integration = integrations.FirstOrDefault(i => i.Type == IntegrationType.Slack); + + if (integration is null) + { + // No slack integration exists, create Initiated version + integration = await integrationRepository.CreateAsync(new OrganizationIntegration + { + OrganizationId = organizationId, + Type = IntegrationType.Slack, + Configuration = null, + }); + } + else if (integration.Configuration is not null) + { + // A Completed (fully configured) Slack integration already exists, throw to prevent overriding + throw new BadRequestException("There already exists a Slack integration for this organization"); + + } // An Initiated slack integration exits, re-use it and kick off a new OAuth flow + + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + var redirectUrl = slackService.GetRedirectUrl( + callbackUrl: callbackUrl, + state: state.ToString() + ); if (string.IsNullOrEmpty(redirectUrl)) { @@ -46,23 +76,42 @@ public class SlackIntegrationController( return Redirect(redirectUrl); } - [HttpGet("create", Name = nameof(CreateAsync))] - public async Task CreateAsync(Guid organizationId, [FromQuery] string code) + [HttpGet("integrations/slack/create", Name = nameof(CreateAsync))] + [AllowAnonymous] + public async Task CreateAsync([FromQuery] string code, [FromQuery] string state) { - if (!await currentContext.OrganizationOwner(organizationId)) + var oAuthState = IntegrationOAuthState.FromString(state: state, timeProvider: timeProvider); + if (oAuthState is null) { throw new NotFoundException(); } - if (string.IsNullOrEmpty(code)) + // Fetch existing Initiated record + var integration = await integrationRepository.GetByIdAsync(oAuthState.IntegrationId); + if (integration is null || + integration.Type != IntegrationType.Slack || + integration.Configuration is not null) { - throw new BadRequestException("Missing code from Slack."); + throw new NotFoundException(); } - string callbackUrl = Url.RouteUrl( - nameof(CreateAsync), - new { organizationId }, - currentContext.HttpContext.Request.Scheme); + // Verify Organization matches hash + if (!oAuthState.ValidateOrg(integration.OrganizationId)) + { + throw new NotFoundException(); + } + + // Fetch token from Slack and store to DB + string? callbackUrl = Url.RouteUrl( + routeName: nameof(CreateAsync), + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } var token = await slackService.ObtainTokenViaOAuth(code, callbackUrl); if (string.IsNullOrEmpty(token)) @@ -70,14 +119,10 @@ public class SlackIntegrationController( throw new BadRequestException("Invalid response from Slack."); } - var integration = await integrationRepository.CreateAsync(new OrganizationIntegration - { - OrganizationId = organizationId, - Type = IntegrationType.Slack, - Configuration = JsonSerializer.Serialize(new SlackIntegration(token)), - }); - var location = $"/organizations/{organizationId}/integrations/{integration.Id}"; + integration.Configuration = JsonSerializer.Serialize(new SlackIntegration(token)); + await integrationRepository.UpsertAsync(integration); + var location = $"/organizations/{integration.OrganizationId}/integrations/{integration.Id}"; return Created(location, new OrganizationIntegrationResponseModel(integration)); } } diff --git a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs index 1f50c384a3..41cebe8b9b 100644 --- a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -3,8 +3,7 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; -using Bit.Api.Billing.Models.Requests; -using Bit.Api.Models.Request; +using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Utilities; @@ -28,8 +27,9 @@ public class ProviderSetupRequestModel [Required] public string Key { get; set; } [Required] - public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } - public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } + public MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; } + [Required] + public BillingAddressRequest BillingAddress { get; set; } public virtual Provider ToProvider(Provider provider) { diff --git a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs b/src/Api/AdminConsole/Models/Response/EventResponseModel.cs index bf02d8b00f..c259bc3bc4 100644 --- a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/EventResponseModel.cs @@ -35,6 +35,7 @@ public class EventResponseModel : ResponseModel SecretId = ev.SecretId; ProjectId = ev.ProjectId; ServiceAccountId = ev.ServiceAccountId; + GrantedServiceAccountId = ev.GrantedServiceAccountId; } public EventType Type { get; set; } @@ -58,4 +59,5 @@ public class EventResponseModel : ResponseModel public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs index f062ff46a2..5368f78e39 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs @@ -2,8 +2,6 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -#nullable enable - namespace Bit.Api.AdminConsole.Models.Response.Organizations; public class OrganizationIntegrationResponseModel : ResponseModel @@ -21,4 +19,29 @@ public class OrganizationIntegrationResponseModel : ResponseModel public Guid Id { get; set; } public IntegrationType Type { get; set; } public string? Configuration { get; set; } + + public OrganizationIntegrationStatus Status => Type switch + { + // Not yet implemented, shouldn't be present, NotApplicable + IntegrationType.CloudBillingSync => OrganizationIntegrationStatus.NotApplicable, + IntegrationType.Scim => OrganizationIntegrationStatus.NotApplicable, + + // Webhook is allowed to be null. If it's present, it's Completed + IntegrationType.Webhook => OrganizationIntegrationStatus.Completed, + + // If present and the configuration is null, OAuth has been initiated, and we are + // waiting on the return call + IntegrationType.Slack => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Initiated + : OrganizationIntegrationStatus.Completed, + + // HEC and Datadog should only be allowed to be created non-null. + // If they are null, they are Invalid + IntegrationType.Hec => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Invalid + : OrganizationIntegrationStatus.Completed, + IntegrationType.Datadog => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Invalid + : OrganizationIntegrationStatus.Completed, + }; } diff --git a/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs b/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs index fcf386d7ee..349bdebb88 100644 --- a/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs +++ b/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs @@ -121,7 +121,7 @@ public class SsoConfigurationDataRequest : IValidatableObject new[] { nameof(IdpEntityId) }); } - if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) + if (string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) { yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), new[] { nameof(IdpSingleSignOnServiceUrl) }); @@ -139,6 +139,7 @@ public class SsoConfigurationDataRequest : IValidatableObject new[] { nameof(IdpSingleLogoutServiceUrl) }); } + // TODO: On server, make public certificate required for SAML2 SSO: https://bitwarden.atlassian.net/browse/PM-26028 if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) { // Validate the certificate is in a valid format diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 21b17bff67..1d6bf51661 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -1,16 +1,8 @@ -#nullable enable -using System.Diagnostics; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.Billing.Models.Requests; +using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; -using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Repositories; using Bit.Core.Services; @@ -28,10 +20,8 @@ public class OrganizationBillingController( IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, IPaymentService paymentService, - IPricingClient pricingClient, ISubscriberService subscriberService, - IPaymentHistoryService paymentHistoryService, - IUserService userService) : BaseBillingController + IPaymentHistoryService paymentHistoryService) : BaseBillingController { [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) @@ -264,71 +254,6 @@ public class OrganizationBillingController( return TypedResults.Ok(); } - [HttpPost("restart-subscription")] - public async Task RestartSubscriptionAsync([FromRoute] Guid organizationId, - [FromBody] OrganizationCreateRequestModel model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - if (organization == null) - { - return Error.NotFound(); - } - var existingPlan = organization.PlanType; - var organizationSignup = model.ToOrganizationSignup(user); - var sale = OrganizationSale.From(organization, organizationSignup); - var plan = await pricingClient.GetPlanOrThrow(model.PlanType); - sale.Organization.PlanType = plan.Type; - sale.Organization.Plan = plan.Name; - sale.SubscriptionSetup.SkipTrial = true; - if (existingPlan == PlanType.Free && organization.GatewaySubscriptionId is not null) - { - sale.Organization.UseTotp = plan.HasTotp; - sale.Organization.UseGroups = plan.HasGroups; - sale.Organization.UseDirectory = plan.HasDirectory; - sale.Organization.SelfHost = plan.HasSelfHost; - sale.Organization.UsersGetPremium = plan.UsersGetPremium; - sale.Organization.UseEvents = plan.HasEvents; - sale.Organization.Use2fa = plan.Has2fa; - sale.Organization.UseApi = plan.HasApi; - sale.Organization.UsePolicies = plan.HasPolicies; - sale.Organization.UseSso = plan.HasSso; - sale.Organization.UseResetPassword = plan.HasResetPassword; - sale.Organization.UseKeyConnector = plan.HasKeyConnector ? organization.UseKeyConnector : false; - sale.Organization.UseScim = plan.HasScim; - sale.Organization.UseCustomPermissions = plan.HasCustomPermissions; - sale.Organization.UseOrganizationDomains = plan.HasOrganizationDomains; - sale.Organization.MaxCollections = plan.PasswordManager.MaxCollections; - } - - if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken)) - { - return Error.BadRequest("A payment method is required to restart the subscription."); - } - var org = await organizationRepository.GetByIdAsync(organizationId); - Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine."); - var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); - var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); - await organizationBillingService.Finalize(sale); - var updatedOrg = await organizationRepository.GetByIdAsync(organizationId); - if (updatedOrg != null) - { - await organizationBillingService.UpdatePaymentMethod(updatedOrg, paymentSource, taxInformation); - } - - return TypedResults.Ok(); - } - [HttpPost("setup-business-unit")] [SelfHosted(NotSelfHostedOnly = true)] public async Task SetupBusinessUnitAsync( diff --git a/src/Api/Billing/Controllers/TaxController.cs b/src/Api/Billing/Controllers/TaxController.cs index d2c1c36726..4ead414589 100644 --- a/src/Api/Billing/Controllers/TaxController.cs +++ b/src/Api/Billing/Controllers/TaxController.cs @@ -1,33 +1,73 @@ -using Bit.Api.Billing.Models.Requests; -using Bit.Core.Billing.Tax.Commands; +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Tax; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Organizations.Commands; +using Bit.Core.Billing.Premium.Commands; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; namespace Bit.Api.Billing.Controllers; [Authorize("Application")] -[Route("tax")] +[Route("billing/tax")] public class TaxController( - IPreviewTaxAmountCommand previewTaxAmountCommand) : BaseBillingController + IPreviewOrganizationTaxCommand previewOrganizationTaxCommand, + IPreviewPremiumTaxCommand previewPremiumTaxCommand) : BaseBillingController { - [HttpPost("preview-amount/organization-trial")] - public async Task PreviewTaxAmountForOrganizationTrialAsync( - [FromBody] PreviewTaxAmountForOrganizationTrialRequestBody requestBody) + [HttpPost("organizations/subscriptions/purchase")] + public async Task PreviewOrganizationSubscriptionPurchaseTaxAsync( + [FromBody] PreviewOrganizationSubscriptionPurchaseTaxRequest request) { - var parameters = new OrganizationTrialParameters + var (purchase, billingAddress) = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(purchase, billingAddress); + return Handle(result.Map(pair => new { - PlanType = requestBody.PlanType, - ProductType = requestBody.ProductType, - TaxInformation = new OrganizationTrialParameters.TaxInformationDTO - { - Country = requestBody.TaxInformation.Country, - PostalCode = requestBody.TaxInformation.PostalCode, - TaxId = requestBody.TaxInformation.TaxId - } - }; + pair.Tax, + pair.Total + })); + } - var result = await previewTaxAmountCommand.Run(parameters); + [HttpPost("organizations/{organizationId:guid}/subscription/plan-change")] + [InjectOrganization] + public async Task PreviewOrganizationSubscriptionPlanChangeTaxAsync( + [BindNever] Organization organization, + [FromBody] PreviewOrganizationSubscriptionPlanChangeTaxRequest request) + { + var (planChange, billingAddress) = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(organization, planChange, billingAddress); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); + } - return Handle(result); + [HttpPut("organizations/{organizationId:guid}/subscription/update")] + [InjectOrganization] + public async Task PreviewOrganizationSubscriptionUpdateTaxAsync( + [BindNever] Organization organization, + [FromBody] PreviewOrganizationSubscriptionUpdateTaxRequest request) + { + var update = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(organization, update); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); + } + + [HttpPost("premium/subscriptions/purchase")] + public async Task PreviewPremiumSubscriptionPurchaseTaxAsync( + [FromBody] PreviewPremiumSubscriptionPurchaseTaxRequest request) + { + var (purchase, billingAddress) = request.ToDomain(); + var result = await previewPremiumTaxCommand.Run(purchase, billingAddress); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); } } diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index a996290507..b01b629e4f 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Core; @@ -67,7 +66,7 @@ public class AccountBillingVNextController( } [HttpPost("subscription")] - [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)] [InjectUser] public async Task CreateSubscriptionAsync( [BindNever] User user, diff --git a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs index ee98031dbc..2f825f2cb9 100644 --- a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs @@ -2,11 +2,14 @@ using Bit.Api.AdminConsole.Authorization.Requirements; using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.Subscriptions; using Bit.Api.Billing.Models.Requirements; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -24,6 +27,7 @@ public class OrganizationBillingVNextController( IGetCreditQuery getCreditQuery, IGetOrganizationWarningsQuery getOrganizationWarningsQuery, IGetPaymentMethodQuery getPaymentMethodQuery, + IRestartSubscriptionCommand restartSubscriptionCommand, IUpdateBillingAddressCommand updateBillingAddressCommand, IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController { @@ -95,6 +99,20 @@ public class OrganizationBillingVNextController( return Handle(result); } + [Authorize] + [HttpPost("subscription/restart")] + [InjectOrganization] + public async Task RestartSubscriptionAsync( + [BindNever] Organization organization, + [FromBody] RestartSubscriptionRequest request) + { + var (paymentMethod, billingAddress) = request.ToDomain(); + var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, null) + .AndThenAsync(_ => updateBillingAddressCommand.Run(organization, billingAddress)) + .AndThenAsync(_ => restartSubscriptionCommand.Run(organization)); + return Handle(result); + } + [Authorize] [HttpGet("warnings")] [InjectOrganization] diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs index 544753ad0f..973a7d99a1 100644 --- a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs +++ b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs @@ -21,7 +21,7 @@ public class SelfHostedAccountBillingController( ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController { [HttpPost("license")] - [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)] [InjectUser] public async Task UploadLicenseAsync( [BindNever] User user, diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs new file mode 100644 index 0000000000..a3856bf173 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs @@ -0,0 +1,31 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionPlanChangeRequest : IValidatableObject +{ + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public ProductTierType Tier { get; set; } + + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public PlanCadenceType Cadence { get; set; } + + public OrganizationSubscriptionPlanChange ToDomain() => new() + { + Tier = Tier, + Cadence = Cadence + }; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Tier == ProductTierType.Families && Cadence == PlanCadenceType.Monthly) + { + yield return new ValidationResult("Monthly billing cadence is not available for the Families plan."); + } + } +} diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs new file mode 100644 index 0000000000..c678b1966c --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs @@ -0,0 +1,84 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionPurchaseRequest : IValidatableObject +{ + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public ProductTierType Tier { get; set; } + + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public PlanCadenceType Cadence { get; set; } + + [Required] + public required PasswordManagerPurchaseSelections PasswordManager { get; set; } + + public SecretsManagerPurchaseSelections? SecretsManager { get; set; } + + public OrganizationSubscriptionPurchase ToDomain() => new() + { + Tier = Tier, + Cadence = Cadence, + PasswordManager = new OrganizationSubscriptionPurchase.PasswordManagerSelections + { + Seats = PasswordManager.Seats, + AdditionalStorage = PasswordManager.AdditionalStorage, + Sponsored = PasswordManager.Sponsored + }, + SecretsManager = SecretsManager != null ? new OrganizationSubscriptionPurchase.SecretsManagerSelections + { + Seats = SecretsManager.Seats, + AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts, + Standalone = SecretsManager.Standalone + } : null + }; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Tier != ProductTierType.Families) + { + yield break; + } + + if (Cadence == PlanCadenceType.Monthly) + { + yield return new ValidationResult("Monthly cadence is not available on the Families plan."); + } + + if (SecretsManager != null) + { + yield return new ValidationResult("Secrets Manager is not available on the Families plan."); + } + } + + public record PasswordManagerPurchaseSelections + { + [Required] + [Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")] + public int Seats { get; set; } + + [Required] + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")] + public int AdditionalStorage { get; set; } + + public bool Sponsored { get; set; } = false; + } + + public record SecretsManagerPurchaseSelections + { + [Required] + [Range(1, 100000, ErrorMessage = "Secrets Manager seats must be between 1 and 100,000")] + public int Seats { get; set; } + + [Required] + [Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")] + public int AdditionalServiceAccounts { get; set; } + + public bool Standalone { get; set; } = false; + } +} diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs new file mode 100644 index 0000000000..ad5c3bd609 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs @@ -0,0 +1,48 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionUpdateRequest +{ + public PasswordManagerUpdateSelections? PasswordManager { get; set; } + public SecretsManagerUpdateSelections? SecretsManager { get; set; } + + public OrganizationSubscriptionUpdate ToDomain() => new() + { + PasswordManager = + PasswordManager != null + ? new OrganizationSubscriptionUpdate.PasswordManagerSelections + { + Seats = PasswordManager.Seats, + AdditionalStorage = PasswordManager.AdditionalStorage + } + : null, + SecretsManager = + SecretsManager != null + ? new OrganizationSubscriptionUpdate.SecretsManagerSelections + { + Seats = SecretsManager.Seats, + AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts + } + : null + }; + + public record PasswordManagerUpdateSelections + { + [Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")] + public int? Seats { get; set; } + + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")] + public int? AdditionalStorage { get; set; } + } + + public record SecretsManagerUpdateSelections + { + [Range(0, 100000, ErrorMessage = "Secrets Manager seats must be between 0 and 100,000")] + public int? Seats { get; set; } + + [Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")] + public int? AdditionalServiceAccounts { get; set; } + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs index 5c3c47f585..0426a51f10 100644 --- a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs index bb6e7498d7..ec1405c566 100644 --- a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs index 54116e897d..ccf2b30b50 100644 --- a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs index b4d28017d5..29c10e6631 100644 --- a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs index 3b50d2bf63..b0e415c262 100644 --- a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Api.Billing.Attributes; using Bit.Core.Billing.Payment.Models; @@ -14,12 +13,9 @@ public class MinimalTokenizedPaymentMethodRequest [Required] public required string Token { get; set; } - public TokenizedPaymentMethod ToDomain() + public TokenizedPaymentMethod ToDomain() => new() { - return new TokenizedPaymentMethod - { - Type = TokenizablePaymentMethodTypeExtensions.From(Type), - Token = Token - }; - } + Type = TokenizablePaymentMethodTypeExtensions.From(Type), + Token = Token + }; } diff --git a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs index f540957a1a..2a54313421 100644 --- a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs @@ -1,31 +1,15 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; -using Bit.Api.Billing.Attributes; -using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; -public class TokenizedPaymentMethodRequest +public class TokenizedPaymentMethodRequest : MinimalTokenizedPaymentMethodRequest { - [Required] - [PaymentMethodTypeValidation] - public required string Type { get; set; } - - [Required] - public required string Token { get; set; } - public MinimalBillingAddressRequest? BillingAddress { get; set; } - public (TokenizedPaymentMethod, BillingAddress?) ToDomain() + public new (TokenizedPaymentMethod, BillingAddress?) ToDomain() { - var paymentMethod = new TokenizedPaymentMethod - { - Type = TokenizablePaymentMethodTypeExtensions.From(Type), - Token = Token - }; - + var paymentMethod = base.ToDomain(); var billingAddress = BillingAddress?.ToDomain(); - return (paymentMethod, billingAddress); } } diff --git a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs index b958057f5b..03f20ec9c1 100644 --- a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Payment.Models; diff --git a/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs b/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs deleted file mode 100644 index a3fda0fd6c..0000000000 --- a/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs +++ /dev/null @@ -1,27 +0,0 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; -using Bit.Core.Billing.Enums; - -namespace Bit.Api.Billing.Models.Requests; - -public class PreviewTaxAmountForOrganizationTrialRequestBody -{ - [Required] - public PlanType PlanType { get; set; } - - [Required] - public ProductType ProductType { get; set; } - - [Required] public TaxInformationDTO TaxInformation { get; set; } = null!; - - public class TaxInformationDTO - { - [Required] - public string Country { get; set; } = null!; - - [Required] - public string PostalCode { get; set; } = null!; - - public string? TaxId { get; set; } - } -} diff --git a/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs new file mode 100644 index 0000000000..ac66270427 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs @@ -0,0 +1,16 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Subscriptions; + +public class RestartSubscriptionRequest +{ + [Required] + public required MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; } + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (TokenizedPaymentMethod, BillingAddress) ToDomain() + => (PaymentMethod.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs new file mode 100644 index 0000000000..9233a53c85 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewOrganizationSubscriptionPlanChangeTaxRequest +{ + [Required] + public required OrganizationSubscriptionPlanChangeRequest Plan { get; set; } + + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (OrganizationSubscriptionPlanChange, BillingAddress) ToDomain() => + (Plan.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs new file mode 100644 index 0000000000..dcc5911f3d --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewOrganizationSubscriptionPurchaseTaxRequest +{ + [Required] + public required OrganizationSubscriptionPurchaseRequest Purchase { get; set; } + + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (OrganizationSubscriptionPurchase, BillingAddress) ToDomain() => + (Purchase.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs new file mode 100644 index 0000000000..ae96214ae3 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs @@ -0,0 +1,11 @@ +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public class PreviewOrganizationSubscriptionUpdateTaxRequest +{ + public required OrganizationSubscriptionUpdateRequest Update { get; set; } + + public OrganizationSubscriptionUpdate ToDomain() => Update.ToDomain(); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs new file mode 100644 index 0000000000..76b8a5a444 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs @@ -0,0 +1,17 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewPremiumSubscriptionPurchaseTaxRequest +{ + [Required] + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB.")] + public short AdditionalStorage { get; set; } + + [Required] + public required MinimalBillingAddressRequest BillingAddress { get; set; } + + public (short, BillingAddress) ToDomain() => (AdditionalStorage, BillingAddress.ToDomain()); +} diff --git a/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs b/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs index cd65a7cdf8..ad5d5e092b 100644 --- a/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs +++ b/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs @@ -29,6 +29,7 @@ public class AccessPoliciesController : Controller private readonly IServiceAccountRepository _serviceAccountRepository; private readonly IUpdateServiceAccountGrantedPoliciesCommand _updateServiceAccountGrantedPoliciesCommand; private readonly IUserService _userService; + private readonly IEventService _eventService; private readonly IProjectServiceAccountsAccessPoliciesUpdatesQuery _projectServiceAccountsAccessPoliciesUpdatesQuery; private readonly IUpdateProjectServiceAccountsAccessPoliciesCommand @@ -47,7 +48,8 @@ public class AccessPoliciesController : Controller IServiceAccountGrantedPolicyUpdatesQuery serviceAccountGrantedPolicyUpdatesQuery, IProjectServiceAccountsAccessPoliciesUpdatesQuery projectServiceAccountsAccessPoliciesUpdatesQuery, IUpdateServiceAccountGrantedPoliciesCommand updateServiceAccountGrantedPoliciesCommand, - IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand) + IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand, + IEventService eventService) { _authorizationService = authorizationService; _userService = userService; @@ -61,6 +63,7 @@ public class AccessPoliciesController : Controller _serviceAccountGrantedPolicyUpdatesQuery = serviceAccountGrantedPolicyUpdatesQuery; _projectServiceAccountsAccessPoliciesUpdatesQuery = projectServiceAccountsAccessPoliciesUpdatesQuery; _updateProjectServiceAccountsAccessPoliciesCommand = updateProjectServiceAccountsAccessPoliciesCommand; + _eventService = eventService; } [HttpGet("/organizations/{id}/access-policies/people/potential-grantees")] @@ -186,7 +189,9 @@ public class AccessPoliciesController : Controller } var userId = _userService.GetProperUserId(User)!.Value; + var currentPolicies = await _accessPolicyRepository.GetPeoplePoliciesByGrantedServiceAccountIdAsync(peopleAccessPolicies.Id, userId); var results = await _accessPolicyRepository.ReplaceServiceAccountPeopleAsync(peopleAccessPolicies, userId); + await LogAccessPolicyServiceAccountChanges(currentPolicies, results, userId); return new ServiceAccountPeopleAccessPoliciesResponseModel(results, userId); } @@ -336,4 +341,39 @@ public class AccessPoliciesController : Controller userId, accessClient); return new ServiceAccountGrantedPoliciesPermissionDetailsResponseModel(results); } + + public async Task LogAccessPolicyServiceAccountChanges(IEnumerable currentPolicies, IEnumerable updatedPolicies, Guid userId) + { + foreach (var current in currentPolicies.OfType()) + { + if (!updatedPolicies.Any(r => r.Id == current.Id)) + { + await _eventService.LogServiceAccountGroupEventAsync(userId, current, EventType.ServiceAccount_GroupRemoved, _currentContext.IdentityClientType); + } + } + + foreach (var policy in updatedPolicies.OfType()) + { + if (!currentPolicies.Any(e => e.Id == policy.Id)) + { + await _eventService.LogServiceAccountGroupEventAsync(userId, policy, EventType.ServiceAccount_GroupAdded, _currentContext.IdentityClientType); + } + } + + foreach (var current in currentPolicies.OfType()) + { + if (!updatedPolicies.Any(r => r.Id == current.Id)) + { + await _eventService.LogServiceAccountPeopleEventAsync(userId, current, EventType.ServiceAccount_UserRemoved, _currentContext.IdentityClientType); + } + } + + foreach (var policy in updatedPolicies.OfType()) + { + if (!currentPolicies.Any(e => e.Id == policy.Id)) + { + await _eventService.LogServiceAccountPeopleEventAsync(userId, policy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType); + } + } + } } diff --git a/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs b/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs index 499c496cc9..0afdc3a1bf 100644 --- a/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs +++ b/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs @@ -42,6 +42,8 @@ public class ServiceAccountsController : Controller private readonly IDeleteServiceAccountsCommand _deleteServiceAccountsCommand; private readonly IRevokeAccessTokensCommand _revokeAccessTokensCommand; private readonly IPricingClient _pricingClient; + private readonly IEventService _eventService; + private readonly IOrganizationUserRepository _organizationUserRepository; public ServiceAccountsController( ICurrentContext currentContext, @@ -58,7 +60,9 @@ public class ServiceAccountsController : Controller IUpdateServiceAccountCommand updateServiceAccountCommand, IDeleteServiceAccountsCommand deleteServiceAccountsCommand, IRevokeAccessTokensCommand revokeAccessTokensCommand, - IPricingClient pricingClient) + IPricingClient pricingClient, + IEventService eventService, + IOrganizationUserRepository organizationUserRepository) { _currentContext = currentContext; _userService = userService; @@ -75,6 +79,8 @@ public class ServiceAccountsController : Controller _pricingClient = pricingClient; _createAccessTokenCommand = createAccessTokenCommand; _updateSecretsManagerSubscriptionCommand = updateSecretsManagerSubscriptionCommand; + _eventService = eventService; + _organizationUserRepository = organizationUserRepository; } [HttpGet("/organizations/{organizationId}/service-accounts")] @@ -139,8 +145,15 @@ public class ServiceAccountsController : Controller } var userId = _userService.GetProperUserId(User).Value; + var result = - await _createServiceAccountCommand.CreateAsync(createRequest.ToServiceAccount(organizationId), userId); + await _createServiceAccountCommand.CreateAsync(serviceAccount, userId); + + if (result != null) + { + await _eventService.LogServiceAccountEventAsync(userId, [serviceAccount], EventType.ServiceAccount_Created, _currentContext.IdentityClientType); + } + return new ServiceAccountResponseModel(result); } @@ -197,6 +210,9 @@ public class ServiceAccountsController : Controller } await _deleteServiceAccountsCommand.DeleteServiceAccounts(serviceAccountsToDelete); + var userId = _userService.GetProperUserId(User)!.Value; + await _eventService.LogServiceAccountEventAsync(userId, serviceAccountsToDelete, EventType.ServiceAccount_Deleted, _currentContext.IdentityClientType); + var responses = results.Select(r => new BulkDeleteResponseModel(r.ServiceAccount.Id, r.Error)); return new ListResponseModel(responses); } diff --git a/src/Core/AdminConsole/Entities/Event.cs b/src/Core/AdminConsole/Entities/Event.cs index 38d8f07b53..e2868c1915 100644 --- a/src/Core/AdminConsole/Entities/Event.cs +++ b/src/Core/AdminConsole/Entities/Event.cs @@ -34,6 +34,7 @@ public class Event : ITableObject, IEvent SecretId = e.SecretId; ProjectId = e.ProjectId; ServiceAccountId = e.ServiceAccountId; + GrantedServiceAccountId = e.GrantedServiceAccountId; } public Guid Id { get; set; } @@ -59,7 +60,7 @@ public class Event : ITableObject, IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } - + public Guid? GrantedServiceAccountId { get; set; } public void SetNewId() { Id = CoreHelpers.GenerateComb(); diff --git a/src/Core/AdminConsole/Enums/EventType.cs b/src/Core/AdminConsole/Enums/EventType.cs index 81501fd6ec..8073938fc5 100644 --- a/src/Core/AdminConsole/Enums/EventType.cs +++ b/src/Core/AdminConsole/Enums/EventType.cs @@ -70,8 +70,8 @@ public enum EventType : int Organization_EnabledKeyConnector = 1606, Organization_DisabledKeyConnector = 1607, Organization_SponsorshipsSynced = 1608, - [Obsolete("Use other specific Organization_CollectionManagement events instead")] - Organization_CollectionManagement_Updated = 1609, // TODO: Will be removed in PM-25315 + [Obsolete("Kept for historical data. Use specific Organization_CollectionManagement events instead.")] + Organization_CollectionManagement_Updated = 1609, Organization_CollectionManagement_LimitCollectionCreationEnabled = 1610, Organization_CollectionManagement_LimitCollectionCreationDisabled = 1611, Organization_CollectionManagement_LimitCollectionDeletionEnabled = 1612, @@ -109,4 +109,11 @@ public enum EventType : int Project_Created = 2201, Project_Edited = 2202, Project_Deleted = 2203, + + ServiceAccount_UserAdded = 2300, + ServiceAccount_UserRemoved = 2301, + ServiceAccount_GroupAdded = 2302, + ServiceAccount_GroupRemoved = 2303, + ServiceAccount_Created = 2304, + ServiceAccount_Deleted = 2305, } diff --git a/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs b/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs new file mode 100644 index 0000000000..78a7bc6d63 --- /dev/null +++ b/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs @@ -0,0 +1,10 @@ +namespace Bit.Api.AdminConsole.Models.Response.Organizations; + +public enum OrganizationIntegrationStatus : int +{ + NotApplicable, + Invalid, + Initiated, + InProgress, + Completed +} diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 452fbcce01..3ac14d67f3 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -20,6 +20,7 @@ public enum PolicyType : byte RestrictedItemTypesPolicy = 15, UriMatchDefaults = 16, AutotypeDefaultSetting = 17, + AutomaticUserConfirmation = 18, } public static class PolicyTypeExtensions @@ -50,6 +51,7 @@ public static class PolicyTypeExtensions PolicyType.RestrictedItemTypesPolicy => "Restricted item types", PolicyType.UriMatchDefaults => "URI match defaults", PolicyType.AutotypeDefaultSetting => "Autotype default setting", + PolicyType.AutomaticUserConfirmation => "Automatically confirm invited users", }; } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs index 7b2dd1343e..7df1459941 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs @@ -5,4 +5,6 @@ public interface IEventListenerConfiguration public string EventQueueName { get; } public string EventSubscriptionName { get; } public string EventTopicName { get; } + public int EventPrefetchCount { get; } + public int EventMaxConcurrentCalls { get; } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs index 322a1cd952..30401bb072 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs @@ -10,6 +10,8 @@ public interface IIntegrationListenerConfiguration : IEventListenerConfiguration public string IntegrationSubscriptionName { get; } public string IntegrationTopicName { get; } public int MaxRetries { get; } + public int IntegrationPrefetchCount { get; } + public int IntegrationMaxConcurrentCalls { get; } public string RoutingKey { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs new file mode 100644 index 0000000000..3b29bbebb4 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs @@ -0,0 +1,71 @@ +using System.Security.Cryptography; +using System.Text; +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; + +public class IntegrationOAuthState +{ + private const int _orgHashLength = 12; + private static readonly TimeSpan _maxAge = TimeSpan.FromMinutes(20); + + public Guid IntegrationId { get; } + private DateTimeOffset Issued { get; } + private string OrganizationIdHash { get; } + + private IntegrationOAuthState(Guid integrationId, string organizationIdHash, DateTimeOffset issued) + { + IntegrationId = integrationId; + OrganizationIdHash = organizationIdHash; + Issued = issued; + } + + public static IntegrationOAuthState FromIntegration(OrganizationIntegration integration, TimeProvider timeProvider) + { + var integrationId = integration.Id; + var issuedUtc = timeProvider.GetUtcNow(); + var organizationIdHash = ComputeOrgHash(integration.OrganizationId, issuedUtc.ToUnixTimeSeconds()); + + return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc); + } + + public static IntegrationOAuthState? FromString(string state, TimeProvider timeProvider) + { + if (string.IsNullOrWhiteSpace(state)) return null; + + var parts = state.Split('.'); + if (parts.Length != 3) return null; + + // Verify timestamp + if (!long.TryParse(parts[2], out var unixSeconds)) return null; + + var issuedUtc = DateTimeOffset.FromUnixTimeSeconds(unixSeconds); + var now = timeProvider.GetUtcNow(); + var age = now - issuedUtc; + + if (age > _maxAge) return null; + + // Parse integration id and store org + if (!Guid.TryParse(parts[0], out var integrationId)) return null; + var organizationIdHash = parts[1]; + + return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc); + } + + public bool ValidateOrg(Guid orgId) + { + var expected = ComputeOrgHash(orgId, Issued.ToUnixTimeSeconds()); + return expected == OrganizationIdHash; + } + + public override string ToString() + { + return $"{IntegrationId}.{OrganizationIdHash}.{Issued.ToUnixTimeSeconds()}"; + } + + private static string ComputeOrgHash(Guid orgId, long timestamp) + { + var bytes = SHA256.HashData(Encoding.UTF8.GetBytes($"{orgId:N}:{timestamp}")); + return Convert.ToHexString(bytes)[.._orgHashLength]; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs index 662bb8241e..40eb2b3e77 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs @@ -25,4 +25,24 @@ public abstract class ListenerConfiguration { get => _globalSettings.EventLogging.AzureServiceBus.IntegrationTopicName; } + + public int EventPrefetchCount + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount; + } + + public int EventMaxConcurrentCalls + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls; + } + + public int IntegrationPrefetchCount + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount; + } + + public int IntegrationMaxConcurrentCalls + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls; + } } diff --git a/src/Core/AdminConsole/Models/Data/EventMessage.cs b/src/Core/AdminConsole/Models/Data/EventMessage.cs index b708c5bd56..a29d70c203 100644 --- a/src/Core/AdminConsole/Models/Data/EventMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventMessage.cs @@ -39,4 +39,5 @@ public class EventMessage : IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/EventTableEntity.cs b/src/Core/AdminConsole/Models/Data/EventTableEntity.cs index 4ba50aee0d..1c3023f2cf 100644 --- a/src/Core/AdminConsole/Models/Data/EventTableEntity.cs +++ b/src/Core/AdminConsole/Models/Data/EventTableEntity.cs @@ -37,6 +37,7 @@ public class AzureEvent : ITableEntity public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } public EventTableEntity ToEventTableEntity() { @@ -68,6 +69,7 @@ public class AzureEvent : ITableEntity SecretId = SecretId, ServiceAccountId = ServiceAccountId, ProjectId = ProjectId, + GrantedServiceAccountId = GrantedServiceAccountId }; } } @@ -99,6 +101,7 @@ public class EventTableEntity : IEvent SecretId = e.SecretId; ProjectId = e.ProjectId; ServiceAccountId = e.ServiceAccountId; + GrantedServiceAccountId = e.GrantedServiceAccountId; } public string PartitionKey { get; set; } @@ -127,6 +130,7 @@ public class EventTableEntity : IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } public AzureEvent ToAzureEvent() { @@ -157,7 +161,8 @@ public class EventTableEntity : IEvent DomainName = DomainName, SecretId = SecretId, ProjectId = ProjectId, - ServiceAccountId = ServiceAccountId + ServiceAccountId = ServiceAccountId, + GrantedServiceAccountId = GrantedServiceAccountId }; } @@ -232,6 +237,15 @@ public class EventTableEntity : IEvent }); } + if (e.GrantedServiceAccountId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"GrantedServiceAccountId={e.GrantedServiceAccountId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + return entities; } diff --git a/src/Core/AdminConsole/Models/Data/IEvent.cs b/src/Core/AdminConsole/Models/Data/IEvent.cs index 750fb2e2eb..3188c905e4 100644 --- a/src/Core/AdminConsole/Models/Data/IEvent.cs +++ b/src/Core/AdminConsole/Models/Data/IEvent.cs @@ -28,4 +28,5 @@ public interface IEvent Guid? SecretId { get; set; } Guid? ProjectId { get; set; } Guid? ServiceAccountId { get; set; } + Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Core/AdminConsole/Repositories/IEventRepository.cs b/src/Core/AdminConsole/Repositories/IEventRepository.cs index 281d6ec8c7..f0c185561b 100644 --- a/src/Core/AdminConsole/Repositories/IEventRepository.cs +++ b/src/Core/AdminConsole/Repositories/IEventRepository.cs @@ -27,6 +27,7 @@ public interface IEventRepository DateTime startDate, DateTime endDate, PageOptions pageOptions); Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task CreateAsync(IEvent e); Task CreateManyAsync(IEnumerable e); Task> GetManyByOrganizationServiceAccountAsync(Guid organizationId, Guid serviceAccountId, diff --git a/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs b/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs index c9c803b5b2..169b36bf69 100644 --- a/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs @@ -77,12 +77,18 @@ public class EventRepository : IEventRepository return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); } - public async Task> GetManyByOrganizationServiceAccountAsync(Guid organizationId, - Guid serviceAccountId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + public async Task> GetManyByOrganizationServiceAccountAsync( + Guid organizationId, + Guid serviceAccountId, + DateTime startDate, + DateTime endDate, + PageOptions pageOptions) { + return await GetManyServiceAccountAsync( + $"OrganizationId={organizationId}", + serviceAccountId.ToString(), + startDate, endDate, pageOptions); - return await GetManyAsync($"OrganizationId={organizationId}", - $"ServiceAccountId={serviceAccountId}__Date={{0}}", startDate, endDate, pageOptions); } public async Task CreateAsync(IEvent e) @@ -141,6 +147,40 @@ public class EventRepository : IEventRepository } } + public async Task> GetManyServiceAccountAsync( + string partitionKey, + string serviceAccountId, + DateTime startDate, + DateTime endDate, + PageOptions pageOptions) + { + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + var filter = MakeFilterForServiceAccount(partitionKey, serviceAccountId, startDate, endDate); + + var result = new PagedResult(); + var query = _tableClient.QueryAsync(filter, pageOptions.PageSize); + + await using (var enumerator = query.AsPages(pageOptions.ContinuationToken, + pageOptions.PageSize).GetAsyncEnumerator()) + { + if (await enumerator.MoveNextAsync()) + { + result.ContinuationToken = enumerator.Current.ContinuationToken; + + var events = enumerator.Current.Values + .Select(e => e.ToEventTableEntity()) + .ToList(); + + events = events.OrderByDescending(e => e.Date).ToList(); + + result.Data.AddRange(events); + } + } + + return result; + } + public async Task> GetManyAsync(string partitionKey, string rowKey, DateTime startDate, DateTime endDate, PageOptions pageOptions) { @@ -172,4 +212,27 @@ public class EventRepository : IEventRepository { return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'"; } + + private string MakeFilterForServiceAccount( + string partitionKey, + string machineAccountId, + DateTime startDate, + DateTime endDate) + { + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + + var rowKey1Start = $"ServiceAccountId={machineAccountId}__Date={start}"; + var rowKey1End = $"ServiceAccountId={machineAccountId}__Date={end}"; + + var rowKey2Start = $"GrantedServiceAccountId={machineAccountId}__Date={start}"; + var rowKey2End = $"GrantedServiceAccountId={machineAccountId}__Date={end}"; + + var left = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey1Start}' and RowKey ge '{rowKey1End}'"; + var right = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey2Start}' and RowKey ge '{rowKey2End}'"; + + return $"({left}) or ({right})"; + } + + } diff --git a/src/Core/AdminConsole/Services/IEventService.cs b/src/Core/AdminConsole/Services/IEventService.cs index 80e8e63d8c..795c06e254 100644 --- a/src/Core/AdminConsole/Services/IEventService.cs +++ b/src/Core/AdminConsole/Services/IEventService.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; +using Bit.Core.Auth.Identity; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; @@ -37,4 +38,7 @@ public interface IEventService Task LogServiceAccountSecretsEventAsync(Guid serviceAccountId, IEnumerable secrets, EventType type, DateTime? date = null); Task LogUserProjectsEventAsync(Guid userId, IEnumerable projects, EventType type, DateTime? date = null); Task LogServiceAccountProjectsEventAsync(Guid serviceAccountId, IEnumerable projects, EventType type, DateTime? date = null); + Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null); + Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null); + Task LogServiceAccountEventAsync(Guid userId, List serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null); } diff --git a/src/Core/AdminConsole/Services/IProviderService.cs b/src/Core/AdminConsole/Services/IProviderService.cs index 66c49d90c6..2b954346ae 100644 --- a/src/Core/AdminConsole/Services/IProviderService.cs +++ b/src/Core/AdminConsole/Services/IProviderService.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -11,8 +11,7 @@ namespace Bit.Core.AdminConsole.Services; public interface IProviderService { - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource = null); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress); Task UpdateAsync(Provider provider, bool updateBilling = false); Task> InviteUserAsync(ProviderUserInvite invite); diff --git a/src/Core/AdminConsole/Services/ISlackService.cs b/src/Core/AdminConsole/Services/ISlackService.cs index 6c6a846f0d..ff1e03f051 100644 --- a/src/Core/AdminConsole/Services/ISlackService.cs +++ b/src/Core/AdminConsole/Services/ISlackService.cs @@ -5,7 +5,7 @@ public interface ISlackService Task GetChannelIdAsync(string token, string channelName); Task> GetChannelIdsAsync(string token, List channelNames); Task GetDmChannelByEmailAsync(string token, string email); - string GetRedirectUrl(string redirectUrl); + string GetRedirectUrl(string callbackUrl, string state); Task ObtainTokenViaOAuth(string code, string redirectUrl); Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId); } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs index 91f8fac888..a589211687 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs @@ -14,13 +14,14 @@ public class AzureServiceBusEventListenerService : EventLoggingL TConfiguration configuration, IEventMessageHandler handler, IAzureServiceBusService serviceBusService, + ServiceBusProcessorOptions serviceBusOptions, ILoggerFactory loggerFactory) : base(handler, CreateLogger(loggerFactory, configuration)) { _processor = serviceBusService.CreateProcessor( topicName: configuration.EventTopicName, subscriptionName: configuration.EventSubscriptionName, - new ServiceBusProcessorOptions()); + options: serviceBusOptions); } protected override async Task ExecuteAsync(CancellationToken cancellationToken) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs index e415430965..633a53296b 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs @@ -18,6 +18,7 @@ public class AzureServiceBusIntegrationListenerService : Backgro TConfiguration configuration, IIntegrationHandler handler, IAzureServiceBusService serviceBusService, + ServiceBusProcessorOptions serviceBusOptions, ILoggerFactory loggerFactory) { _handler = handler; @@ -29,7 +30,7 @@ public class AzureServiceBusIntegrationListenerService : Backgro _processor = _serviceBusService.CreateProcessor( topicName: configuration.IntegrationTopicName, subscriptionName: configuration.IntegrationSubscriptionName, - options: new ServiceBusProcessorOptions()); + options: serviceBusOptions); } protected override async Task ExecuteAsync(CancellationToken cancellationToken) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs index f17185c4d3..4fb74f1f44 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs @@ -19,6 +19,7 @@ public class SlackService( private readonly string _slackApiBaseUrl = globalSettings.Slack.ApiBaseUrl; public const string HttpClientName = "SlackServiceHttpClient"; + private const string _slackOAuthBaseUri = "https://slack.com/oauth/v2/authorize"; public async Task GetChannelIdAsync(string token, string channelName) { @@ -73,9 +74,18 @@ public class SlackService( return await OpenDmChannel(token, userId); } - public string GetRedirectUrl(string redirectUrl) + public string GetRedirectUrl(string callbackUrl, string state) { - return $"https://slack.com/oauth/v2/authorize?client_id={_clientId}&scope={_scopes}&redirect_uri={redirectUrl}"; + var builder = new UriBuilder(_slackOAuthBaseUri); + var query = HttpUtility.ParseQueryString(builder.Query); + + query["client_id"] = _clientId; + query["scope"] = _scopes; + query["redirect_uri"] = callbackUrl; + query["state"] = state; + + builder.Query = query.ToString(); + return builder.ToString(); } public async Task ObtainTokenViaOAuth(string code, string redirectUrl) diff --git a/src/Core/AdminConsole/Services/Implementations/EventService.cs b/src/Core/AdminConsole/Services/Implementations/EventService.cs index e0e0e040f1..77d481890e 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventService.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -516,6 +517,135 @@ public class EventService : IEventService await _eventWriteService.CreateManyAsync(eventMessages); } + + public async Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + var orgUser = await _organizationUserRepository.GetByIdAsync((Guid)policy.OrganizationUserId); + + if (!CanUseEvents(orgAbilities, orgUser.OrganizationId)) + { + return; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + return; + } + + if (policy.OrganizationUserId != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = orgUser.OrganizationId, + Type = type, + GrantedServiceAccountId = policy.GrantedServiceAccountId, + ServiceAccountId = serviceAccountId, + UserId = policy.OrganizationUserId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + public async Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + + if (!CanUseEvents(orgAbilities, policy.Group.OrganizationId)) + { + return; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + return; + } + + if (policy.GroupId != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = policy.Group.OrganizationId, + Type = type, + GrantedServiceAccountId = policy.GrantedServiceAccountId, + ServiceAccountId = serviceAccountId, + GroupId = policy.GroupId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + public async Task LogServiceAccountEventAsync(Guid userId, List serviceAccounts, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + + foreach (var serviceAccount in serviceAccounts) + { + if (!CanUseEvents(orgAbilities, serviceAccount.OrganizationId)) + { + continue; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + continue; + } + + if (serviceAccount != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = serviceAccount.OrganizationId, + Type = type, + GrantedServiceAccountId = serviceAccount.Id, + ServiceAccountId = serviceAccountId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + } + } + + if (eventMessages.Any()) + { + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + private (Guid? actingUserId, Guid? serviceAccountId) MapIdentityClientType( + Guid userId, IdentityClientType identityClientType) + { + if (identityClientType == IdentityClientType.Organization) + { + return (null, null); + } + + return identityClientType switch + { + IdentityClientType.User => (userId, null), + IdentityClientType.ServiceAccount => (null, userId), + _ => throw new InvalidOperationException("Unknown identity client type.") + }; + } + + private async Task GetProviderIdAsync(Guid? orgId) { if (_currentContext == null || !orgId.HasValue) diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs index e8dd495205..6ecea7d234 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; +using Bit.Core.Auth.Identity; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; @@ -139,4 +140,19 @@ public class NoopEventService : IEventService { return Task.FromResult(0); } + + public Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } + + public Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } + + public Task LogServiceAccountEventAsync(Guid userId, List serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } } diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs index 2bf4a54a87..3782b30e3f 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -11,7 +11,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations; public class NoopProviderService : IProviderService { - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs index c34c073e87..d6c8d08c4c 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs @@ -19,7 +19,7 @@ public class NoopSlackService : ISlackService return Task.FromResult(string.Empty); } - public string GetRedirectUrl(string redirectUrl) + public string GetRedirectUrl(string callbackUrl, string state) { return string.Empty; } diff --git a/src/Core/Billing/Commands/BillingCommandResult.cs b/src/Core/Billing/Commands/BillingCommandResult.cs index 3238ab4107..db260e7038 100644 --- a/src/Core/Billing/Commands/BillingCommandResult.cs +++ b/src/Core/Billing/Commands/BillingCommandResult.cs @@ -1,5 +1,4 @@ -#nullable enable -using OneOf; +using OneOf; namespace Bit.Core.Billing.Commands; @@ -20,18 +19,38 @@ public record Unhandled(Exception? Exception = null, string Response = "Somethin /// /// /// The successful result type of the operation. -public class BillingCommandResult : OneOfBase +public class BillingCommandResult(OneOf input) + : OneOfBase(input) { - private BillingCommandResult(OneOf input) : base(input) { } - public static implicit operator BillingCommandResult(T output) => new(output); public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest); public static implicit operator BillingCommandResult(Conflict conflict) => new(conflict); public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled); + public BillingCommandResult Map(Func f) + => Match( + value => new BillingCommandResult(f(value)), + badRequest => new BillingCommandResult(badRequest), + conflict => new BillingCommandResult(conflict), + unhandled => new BillingCommandResult(unhandled)); + public Task TapAsync(Func f) => Match( f, _ => Task.CompletedTask, _ => Task.CompletedTask, _ => Task.CompletedTask); } + +public static class BillingCommandResultExtensions +{ + public static async Task> AndThenAsync( + this Task> task, Func>> binder) + { + var result = await task; + return await result.Match( + binder, + badRequest => Task.FromResult(new BillingCommandResult(badRequest)), + conflict => Task.FromResult(new BillingCommandResult(conflict)), + unhandled => Task.FromResult(new BillingCommandResult(unhandled))); + } +} diff --git a/src/Core/Billing/Enums/PlanCadenceType.cs b/src/Core/Billing/Enums/PlanCadenceType.cs new file mode 100644 index 0000000000..9e6fa69832 --- /dev/null +++ b/src/Core/Billing/Enums/PlanCadenceType.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Enums; + +public enum PlanCadenceType +{ + Annually, + Monthly +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index b4e37f0151..7aec422a4b 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -9,7 +9,7 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; -using Bit.Core.Billing.Tax.Commands; +using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services.Implementations; @@ -28,11 +28,12 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); - services.AddTransient(); services.AddPaymentOperations(); services.AddOrganizationLicenseCommandsQueries(); services.AddPremiumCommands(); services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services) @@ -46,5 +47,6 @@ public static class ServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs new file mode 100644 index 0000000000..041e9bdbad --- /dev/null +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -0,0 +1,383 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Enums; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using OneOf; +using Stripe; + +namespace Bit.Core.Billing.Organizations.Commands; + +using static Core.Constants; +using static StripeConstants; + +public interface IPreviewOrganizationTaxCommand +{ + Task> Run( + OrganizationSubscriptionPurchase purchase, + BillingAddress billingAddress); + + Task> Run( + Organization organization, + OrganizationSubscriptionPlanChange planChange, + BillingAddress billingAddress); + + Task> Run( + Organization organization, + OrganizationSubscriptionUpdate update); +} + +public class PreviewOrganizationTaxCommand( + ILogger logger, + IPricingClient pricingClient, + IStripeAdapter stripeAdapter) + : BaseBillingCommand(logger), IPreviewOrganizationTaxCommand +{ + public Task> Run( + OrganizationSubscriptionPurchase purchase, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + var plan = await pricingClient.GetPlanOrThrow(purchase.PlanType); + + var options = GetBaseOptions(billingAddress, purchase.Tier != ProductTierType.Families); + + var items = new List(); + + switch (purchase) + { + case { PasswordManager.Sponsored: true }: + var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise); + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = sponsoredPlan.StripePlanId, + Quantity = 1 + }); + break; + + case { SecretsManager.Standalone: true }: + items.AddRange([ + new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.PasswordManager.StripeSeatPlanId, + Quantity = purchase.PasswordManager.Seats + }, + new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeSeatPlanId, + Quantity = purchase.SecretsManager.Seats + } + ]); + options.Coupon = CouponIDs.SecretsManagerStandalone; + break; + + default: + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.HasNonSeatBasedPasswordManagerPlan() + ? plan.PasswordManager.StripePlanId + : plan.PasswordManager.StripeSeatPlanId, + Quantity = purchase.PasswordManager.Seats + }); + + if (purchase.PasswordManager.AdditionalStorage > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.PasswordManager.StripeStoragePlanId, + Quantity = purchase.PasswordManager.AdditionalStorage + }); + } + + if (purchase.SecretsManager is { Seats: > 0 }) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeSeatPlanId, + Quantity = purchase.SecretsManager.Seats + }); + + if (purchase.SecretsManager.AdditionalServiceAccounts > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeServiceAccountPlanId, + Quantity = purchase.SecretsManager.AdditionalServiceAccounts + }); + } + } + + break; + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + public Task> Run( + Organization organization, + OrganizationSubscriptionPlanChange planChange, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + if (organization.PlanType.GetProductTier() == ProductTierType.Free) + { + var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); + + var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType); + + var items = new List + { + new () + { + Price = newPlan.HasNonSeatBasedPasswordManagerPlan() + ? newPlan.PasswordManager.StripePlanId + : newPlan.PasswordManager.StripeSeatPlanId, + Quantity = 2 + } + }; + + if (organization.UseSecretsManager && planChange.Tier != ProductTierType.Families) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeSeatPlanId, + Quantity = 2 + }); + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + } + else + { + if (organization is not + { + GatewayCustomerId: not null, + GatewaySubscriptionId: not null + }) + { + return new BadRequest("Organization does not have a subscription."); + } + + var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); + + var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer"] }); + + if (subscription.Customer.Discount != null) + { + options.Coupon = subscription.Customer.Discount.Coupon.Id; + } + + var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); + var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType); + + var subscriptionItemsByPriceId = + subscription.Items.ToDictionary(subscriptionItem => subscriptionItem.Price.Id); + + var items = new List(); + + var passwordManagerSeats = subscriptionItemsByPriceId[ + currentPlan.HasNonSeatBasedPasswordManagerPlan() + ? currentPlan.PasswordManager.StripePlanId + : currentPlan.PasswordManager.StripeSeatPlanId]; + + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.HasNonSeatBasedPasswordManagerPlan() + ? newPlan.PasswordManager.StripePlanId + : newPlan.PasswordManager.StripeSeatPlanId, + Quantity = passwordManagerSeats.Quantity + }); + + var hasStorage = + subscriptionItemsByPriceId.TryGetValue(newPlan.PasswordManager.StripeStoragePlanId, + out var storage); + + if (hasStorage && storage != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.PasswordManager.StripeStoragePlanId, + Quantity = storage.Quantity + }); + } + + var hasSecretsManagerSeats = subscriptionItemsByPriceId.TryGetValue( + newPlan.SecretsManager.StripeSeatPlanId, + out var secretsManagerSeats); + + if (hasSecretsManagerSeats && secretsManagerSeats != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeSeatPlanId, + Quantity = secretsManagerSeats.Quantity + }); + + var hasServiceAccounts = + subscriptionItemsByPriceId.TryGetValue(newPlan.SecretsManager.StripeServiceAccountPlanId, + out var serviceAccounts); + + if (hasServiceAccounts && serviceAccounts != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeServiceAccountPlanId, + Quantity = serviceAccounts.Quantity + }); + } + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + } + }); + + public Task> Run( + Organization organization, + OrganizationSubscriptionUpdate update) + => HandleAsync<(decimal, decimal)>(async () => + { + if (organization is not + { + GatewayCustomerId: not null, + GatewaySubscriptionId: not null + }) + { + return new BadRequest("Organization does not have a subscription."); + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.tax_ids"] }); + + var options = GetBaseOptions(subscription.Customer, + organization.GetProductUsageType() == ProductUsageType.Business); + + if (subscription.Customer.Discount != null) + { + options.Coupon = subscription.Customer.Discount.Coupon.Id; + } + + var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + var items = new List(); + + if (update.PasswordManager?.Seats != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.HasNonSeatBasedPasswordManagerPlan() + ? currentPlan.PasswordManager.StripePlanId + : currentPlan.PasswordManager.StripeSeatPlanId, + Quantity = update.PasswordManager.Seats + }); + } + + if (update.PasswordManager?.AdditionalStorage is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.PasswordManager.StripeStoragePlanId, + Quantity = update.PasswordManager.AdditionalStorage + }); + } + + if (update.SecretsManager?.Seats is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.SecretsManager.StripeSeatPlanId, + Quantity = update.SecretsManager.Seats + }); + + if (update.SecretsManager.AdditionalServiceAccounts is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.SecretsManager.StripeServiceAccountPlanId, + Quantity = update.SecretsManager.AdditionalServiceAccounts + }); + } + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + private static (decimal, decimal) GetAmounts(Invoice invoice) => ( + Convert.ToDecimal(invoice.Tax) / 100, + Convert.ToDecimal(invoice.Total) / 100); + + private static InvoiceCreatePreviewOptions GetBaseOptions( + OneOf addressChoice, + bool businessUse) + { + var country = addressChoice.Match( + customer => customer.Address.Country, + billingAddress => billingAddress.Country + ); + + var postalCode = addressChoice.Match( + customer => customer.Address.PostalCode, + billingAddress => billingAddress.PostalCode); + + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, + Currency = "usd", + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions { Country = country, PostalCode = postalCode }, + TaxExempt = businessUse && country != CountryAbbreviations.UnitedStates + ? TaxExempt.Reverse + : TaxExempt.None + } + }; + + var taxId = addressChoice.Match( + customer => + { + var taxId = customer.TaxIds?.FirstOrDefault(); + return taxId != null ? new TaxID(taxId.Type, taxId.Value) : null; + }, + billingAddress => billingAddress.TaxId); + + if (taxId == null) + { + return options; + } + + options.CustomerDetails.TaxIds = + [ + new InvoiceCustomerDetailsTaxIdOptions { Type = taxId.Code, Value = taxId.Value } + ]; + + if (taxId.Code == TaxIdType.SpanishNIF) + { + options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions + { + Type = TaxIdType.EUVAT, + Value = $"ES{taxId.Value}" + }); + } + + return options; + } +} diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs new file mode 100644 index 0000000000..7781f91960 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs @@ -0,0 +1,23 @@ +using Bit.Core.Billing.Enums; + +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionPlanChange +{ + public ProductTierType Tier { get; init; } + public PlanCadenceType Cadence { get; init; } + + public PlanType PlanType => + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault + Tier switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => Cadence == PlanCadenceType.Monthly + ? PlanType.TeamsMonthly + : PlanType.TeamsAnnually, + ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly + ? PlanType.EnterpriseMonthly + : PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException("Cannot change an Organization subscription to a tier that isn't Families, Teams or Enterprise.") + }; +} diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs new file mode 100644 index 0000000000..6691d69848 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs @@ -0,0 +1,39 @@ +using Bit.Core.Billing.Enums; + +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionPurchase +{ + public ProductTierType Tier { get; init; } + public PlanCadenceType Cadence { get; init; } + public required PasswordManagerSelections PasswordManager { get; init; } + public SecretsManagerSelections? SecretsManager { get; init; } + + public PlanType PlanType => + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault + Tier switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => Cadence == PlanCadenceType.Monthly + ? PlanType.TeamsMonthly + : PlanType.TeamsAnnually, + ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly + ? PlanType.EnterpriseMonthly + : PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException("Cannot purchase an Organization subscription that isn't Families, Teams or Enterprise.") + }; + + public record PasswordManagerSelections + { + public int Seats { get; init; } + public int AdditionalStorage { get; init; } + public bool Sponsored { get; init; } + } + + public record SecretsManagerSelections + { + public int Seats { get; init; } + public int AdditionalServiceAccounts { get; init; } + public bool Standalone { get; init; } + } +} diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs new file mode 100644 index 0000000000..810f292c81 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs @@ -0,0 +1,19 @@ +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionUpdate +{ + public PasswordManagerSelections? PasswordManager { get; init; } + public SecretsManagerSelections? SecretsManager { get; init; } + + public record PasswordManagerSelections + { + public int? Seats { get; init; } + public int? AdditionalStorage { get; init; } + } + + public record SecretsManagerSelections + { + public int? Seats { get; init; } + public int? AdditionalServiceAccounts { get; init; } + } +} diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs new file mode 100644 index 0000000000..a0b4fcabc2 --- /dev/null +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs @@ -0,0 +1,65 @@ +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Premium.Commands; + +using static StripeConstants; + +public interface IPreviewPremiumTaxCommand +{ + Task> Run( + int additionalStorage, + BillingAddress billingAddress); +} + +public class PreviewPremiumTaxCommand( + ILogger logger, + IStripeAdapter stripeAdapter) : BaseBillingCommand(logger), IPreviewPremiumTaxCommand +{ + public Task> Run( + int additionalStorage, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }, + Currency = "usd", + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = + [ + new InvoiceSubscriptionDetailsItemOptions { Price = Prices.PremiumAnnually, Quantity = 1 } + ] + } + }; + + if (additionalStorage > 0) + { + options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = Prices.StoragePlanPersonal, + Quantity = additionalStorage + }); + } + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + private static (decimal, decimal) GetAmounts(Invoice invoice) => ( + Convert.ToDecimal(invoice.Tax) / 100, + Convert.ToDecimal(invoice.Total) / 100); +} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs index 07a057d40c..e155b427f1 100644 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs +++ b/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs @@ -258,7 +258,7 @@ public class ProviderMigrator( // Create dummy payment source for legacy migration - this migrator is deprecated and will be removed var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token"); - var customer = await providerBillingService.SetupCustomer(provider, taxInfo, dummyPaymentSource); + var customer = await providerBillingService.SetupCustomer(provider, null, null); await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { diff --git a/src/Core/Billing/Providers/Services/IProviderBillingService.cs b/src/Core/Billing/Providers/Services/IProviderBillingService.cs index 173249f79f..57d68db038 100644 --- a/src/Core/Billing/Providers/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Providers/Services/IProviderBillingService.cs @@ -5,10 +5,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Models.Business; using Stripe; namespace Bit.Core.Billing.Providers.Services; @@ -79,16 +79,16 @@ public interface IProviderBillingService int seatAdjustment); /// - /// For use during the provider setup process, this method creates a Stripe for the specified utilizing the provided . + /// For use during the provider setup process, this method creates a Stripe for the specified utilizing the provided and . /// /// The to create a Stripe customer for. - /// The to use for calculating the customer's automatic tax. - /// The (ex. Credit Card) to attach to the customer. + /// The (e.g., Credit Card, Bank Account, or PayPal) to attach to the customer. + /// The containing the customer's billing information including address and tax ID details. /// The newly created for the . Task SetupCustomer( Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource); + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress); /// /// For use during the provider setup process, this method starts a Stripe for the given . diff --git a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs new file mode 100644 index 0000000000..351c75ace0 --- /dev/null +++ b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs @@ -0,0 +1,92 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using OneOf.Types; +using Stripe; + +namespace Bit.Core.Billing.Subscriptions.Commands; + +using static StripeConstants; + +public interface IRestartSubscriptionCommand +{ + Task> Run( + ISubscriber subscriber); +} + +public class RestartSubscriptionCommand( + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService, + IUserRepository userRepository) : IRestartSubscriptionCommand +{ + public async Task> Run( + ISubscriber subscriber) + { + var existingSubscription = await subscriberService.GetSubscription(subscriber); + + if (existingSubscription is not { Status: SubscriptionStatus.Canceled }) + { + return new BadRequest("Cannot restart a subscription that is not canceled."); + } + + var options = new SubscriptionCreateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, + CollectionMethod = CollectionMethod.ChargeAutomatically, + Customer = existingSubscription.CustomerId, + Items = existingSubscription.Items.Select(subscriptionItem => new SubscriptionItemOptions + { + Price = subscriptionItem.Price.Id, + Quantity = subscriptionItem.Quantity + }).ToList(), + Metadata = existingSubscription.Metadata, + OffSession = true, + TrialPeriodDays = 0 + }; + + var subscription = await stripeAdapter.SubscriptionCreateAsync(options); + await EnableAsync(subscriber, subscription); + return new None(); + } + + private async Task EnableAsync(ISubscriber subscriber, Subscription subscription) + { + switch (subscriber) + { + case Organization organization: + { + organization.GatewaySubscriptionId = subscription.Id; + organization.Enabled = true; + organization.ExpirationDate = subscription.CurrentPeriodEnd; + organization.RevisionDate = DateTime.UtcNow; + await organizationRepository.ReplaceAsync(organization); + break; + } + case Provider provider: + { + provider.GatewaySubscriptionId = subscription.Id; + provider.Enabled = true; + provider.RevisionDate = DateTime.UtcNow; + await providerRepository.ReplaceAsync(provider); + break; + } + case User user: + { + user.GatewaySubscriptionId = subscription.Id; + user.Premium = true; + user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + user.RevisionDate = DateTime.UtcNow; + await userRepository.ReplaceAsync(user); + break; + } + } + } +} diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs deleted file mode 100644 index 94d3724d73..0000000000 --- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs +++ /dev/null @@ -1,136 +0,0 @@ -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -namespace Bit.Core.Billing.Tax.Commands; - -public interface IPreviewTaxAmountCommand -{ - Task> Run(OrganizationTrialParameters parameters); -} - -public class PreviewTaxAmountCommand( - ILogger logger, - IPricingClient pricingClient, - IStripeAdapter stripeAdapter, - ITaxService taxService) : BaseBillingCommand(logger), IPreviewTaxAmountCommand -{ - protected override Conflict DefaultConflict - => new("We had a problem calculating your tax obligation. Please contact support for assistance."); - - public Task> Run(OrganizationTrialParameters parameters) - => HandleAsync(async () => - { - var (planType, productType, taxInformation) = parameters; - - var plan = await pricingClient.GetPlanOrThrow(planType); - - var options = new InvoiceCreatePreviewOptions - { - Currency = "usd", - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - Country = taxInformation.Country, - PostalCode = taxInformation.PostalCode - } - }, - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new InvoiceSubscriptionDetailsItemOptions - { - Price = plan.HasNonSeatBasedPasswordManagerPlan() - ? plan.PasswordManager.StripePlanId - : plan.PasswordManager.StripeSeatPlanId, - Quantity = 1 - } - ] - } - }; - - if (productType == ProductType.SecretsManager) - { - options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions - { - Price = plan.SecretsManager.StripeSeatPlanId, - Quantity = 1 - }); - - options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone; - } - - if (!string.IsNullOrEmpty(taxInformation.TaxId)) - { - var taxIdType = taxService.GetStripeTaxCode( - taxInformation.Country, - taxInformation.TaxId); - - if (string.IsNullOrEmpty(taxIdType)) - { - return new BadRequest( - "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance."); - } - - options.CustomerDetails.TaxIds = - [ - new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - if (parameters.PlanType.IsBusinessProductTierType() && - parameters.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) - { - options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); - return Convert.ToDecimal(invoice.Tax) / 100; - }); -} - -#region Command Parameters - -public record OrganizationTrialParameters -{ - public required PlanType PlanType { get; set; } - public required ProductType ProductType { get; set; } - public required TaxInformationDTO TaxInformation { get; set; } - - public void Deconstruct( - out PlanType planType, - out ProductType productType, - out TaxInformationDTO taxInformation) - { - planType = PlanType; - productType = ProductType; - taxInformation = TaxInformation; - } - - public record TaxInformationDTO - { - public required string Country { get; set; } - public required string PostalCode { get; set; } - public string? TaxId { get; set; } - } -} - -#endregion diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index d26e0f67fa..1574c7f2ce 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -70,6 +70,17 @@ public static class Constants /// public const string UnitedStates = "US"; } + + + /// + /// Constants for our browser extensions IDs + /// + public static class BrowserExtensions + { + public const string ChromeId = "chrome-extension://nngceckbapebfimnlniiiahkandclblb/"; + public const string EdgeId = "chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh/"; + public const string OperaId = "chrome-extension://ccnckbpmaceehanjmeomladnmlffdjgn/"; + } } public static class AuthConstants @@ -124,8 +135,6 @@ public static class AuthenticationSchemes public static class FeatureFlagKeys { /* Admin Console Team */ - public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; - public const string LimitItemDeletion = "pm-15493-restrict-item-deletion-to-can-manage-permission"; public const string PolicyRequirements = "pm-14439-policy-requirements"; public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast"; public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations"; @@ -169,10 +178,11 @@ public static class FeatureFlagKeys public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships"; public const string UsePricingService = "use-pricing-service"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; - public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout"; public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; - public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow"; + public const string PM24996ImplementUpgradeFromFreeDialog = "pm-24996-implement-upgrade-from-free-dialog"; + public const string PM24032_NewNavigationPremiumUpgradeButton = "pm-24032-new-navigation-premium-upgrade-button"; + public const string PM23713_PremiumBadgeOpensNewPremiumUpgradeDialog = "pm-23713-premium-badge-opens-new-premium-upgrade-dialog"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; @@ -222,7 +232,6 @@ public static class FeatureFlagKeys /* Vault Team */ public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge"; public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; - public const string SecurityTasks = "security-tasks"; public const string CipherKeyEncryption = "cipher-key-encryption"; public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms"; public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk"; diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index e76af0f8ef..e9bf1b1807 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,8 +21,8 @@ - - + + @@ -34,7 +34,7 @@ - + diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs index f7d6f0e5a2..739dca5228 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs @@ -315,7 +315,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs throw new BadRequestException($"Cannot set max Secrets Manager seat autoscaling below current Secrets Manager seat count."); } - if (plan.SecretsManager.MaxSeats.HasValue && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats) + if (plan.SecretsManager.MaxSeats.HasValue && plan.SecretsManager.MaxSeats.Value > 0 && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats) { throw new BadRequestException(string.Concat( $"Your plan has a Secrets Manager seat limit of {plan.SecretsManager.MaxSeats}, ", diff --git a/src/Core/Resources/SharedResources.en.resx b/src/Core/Resources/SharedResources.en.resx index 17b4489454..28ae70ca96 100644 --- a/src/Core/Resources/SharedResources.en.resx +++ b/src/Core/Resources/SharedResources.en.resx @@ -389,7 +389,7 @@ If SAML Binding Type is set to artifact, identity provider resolution service URL is required. - If Identity Provider Entity ID is not a URL, single sign on service URL is required. + Single sign on service URL is required. The configured authentication scheme is not valid: "{0}" diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 9728c2e727..75e0c78702 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -26,6 +26,7 @@ using Bit.Core.Vault.Models.Data; using Core.Auth.Enums; using HandlebarsDotNet; using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; namespace Bit.Core.Services; @@ -39,6 +40,7 @@ public class HandlebarsMailService : IMailService private readonly IMailDeliveryService _mailDeliveryService; private readonly IMailEnqueuingService _mailEnqueuingService; private readonly IDistributedCache _distributedCache; + private readonly ILogger _logger; private readonly Dictionary> _templateCache = new(); private bool _registeredHelpersAndPartials = false; @@ -47,12 +49,14 @@ public class HandlebarsMailService : IMailService GlobalSettings globalSettings, IMailDeliveryService mailDeliveryService, IMailEnqueuingService mailEnqueuingService, - IDistributedCache distributedCache) + IDistributedCache distributedCache, + ILogger logger) { _globalSettings = globalSettings; _mailDeliveryService = mailDeliveryService; _mailEnqueuingService = mailEnqueuingService; _distributedCache = distributedCache; + _logger = logger; } public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) @@ -708,6 +712,12 @@ public class HandlebarsMailService : IMailService private async Task ReadSourceAsync(string templateName) { + var diskSource = await ReadSourceFromDiskAsync(templateName); + if (!string.IsNullOrWhiteSpace(diskSource)) + { + return diskSource; + } + var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; var fullTemplateName = $"{Namespace}.{templateName}.hbs"; if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) @@ -721,6 +731,42 @@ public class HandlebarsMailService : IMailService } } + private async Task ReadSourceFromDiskAsync(string templateName) + { + if (!_globalSettings.SelfHosted) + { + return null; + } + try + { + var templateFileSuffix = ".html"; + if (templateName.EndsWith(".txt")) + { + templateFileSuffix = ".txt"; + } + else if (!templateName.EndsWith(".html")) + { + // unexpected suffix + return null; + } + var suffixPosition = templateName.LastIndexOf(templateFileSuffix); + var templateNameNoSuffix = templateName.Substring(0, suffixPosition); + var templatePathNoSuffix = templateNameNoSuffix.Replace(".", "/"); + var diskPath = $"{_globalSettings.MailTemplateDirectory}/{templatePathNoSuffix}{templateFileSuffix}.hbs"; + var directory = Path.GetDirectoryName(diskPath); + if (Directory.Exists(directory) && File.Exists(diskPath)) + { + var fileContents = await File.ReadAllTextAsync(diskPath); + return fileContents; + } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to read mail template from disk."); + } + return null; + } + private async Task RegisterHelpersAndPartialsAsync() { if (_registeredHelpersAndPartials) diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index 03d1776e90..4863baf73e 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -3,6 +3,7 @@ using Bit.Core.Models.BitStripe; using Stripe; +using Stripe.Tax; namespace Bit.Core.Services; @@ -23,6 +24,7 @@ public class StripeAdapter : IStripeAdapter private readonly Stripe.TestHelpers.TestClockService _testClockService; private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; private readonly Stripe.Tax.RegistrationService _taxRegistrationService; + private readonly CalculationService _calculationService; public StripeAdapter() { @@ -41,6 +43,7 @@ public class StripeAdapter : IStripeAdapter _testClockService = new Stripe.TestHelpers.TestClockService(); _customerBalanceTransactionService = new CustomerBalanceTransactionService(); _taxRegistrationService = new Stripe.Tax.RegistrationService(); + _calculationService = new CalculationService(); } public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 546e668093..250daf0007 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -8,6 +8,7 @@ namespace Bit.Core.Settings; public class GlobalSettings : IGlobalSettings { + private string _mailTemplateDirectory; private string _logDirectory; private string _licenseDirectory; @@ -37,6 +38,11 @@ public class GlobalSettings : IGlobalSettings get => BuildDirectory(_licenseDirectory, "/core/licenses"); set => _licenseDirectory = value; } + public virtual string MailTemplateDirectory + { + get => BuildDirectory(_mailTemplateDirectory, "/mail-templates"); + set => _mailTemplateDirectory = value; + } public string LicenseCertificatePassword { get; set; } public virtual string PushRelayBaseUri { get; set; } public virtual string InternalIdentityKey { get; set; } @@ -97,6 +103,7 @@ public class GlobalSettings : IGlobalSettings /// public virtual string SendDefaultHashKey { get; set; } public virtual string PricingUri { get; set; } + public virtual Fido2Settings Fido2 { get; set; } = new Fido2Settings(); public string BuildExternalUri(string explicitValue, string name) { @@ -301,6 +308,9 @@ public class GlobalSettings : IGlobalSettings private string _eventTopicName; private string _integrationTopicName; + public virtual int DefaultMaxConcurrentCalls { get; set; } = 1; + public virtual int DefaultPrefetchCount { get; set; } = 0; + public virtual string EventRepositorySubscriptionName { get; set; } = "events-write-subscription"; public virtual string SlackEventSubscriptionName { get; set; } = "events-slack-subscription"; public virtual string SlackIntegrationSubscriptionName { get; set; } = "integration-slack-subscription"; @@ -763,4 +773,9 @@ public class GlobalSettings : IGlobalSettings { public string VapidPublicKey { get; set; } } + + public class Fido2Settings + { + public HashSet Origins { get; set; } + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs index b034f31f39..2ddc5679d5 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs @@ -230,6 +230,8 @@ public class EventRepository : Repository, IEventRepository eventsTable.Columns.Add(serviceAccountIdColumn); var projectIdColumn = new DataColumn(nameof(e.ProjectId), typeof(Guid)); eventsTable.Columns.Add(projectIdColumn); + var grantedServiceAccountIdColumn = new DataColumn(nameof(e.GrantedServiceAccountId), typeof(Guid)); + eventsTable.Columns.Add(grantedServiceAccountIdColumn); foreach (DataColumn col in eventsTable.Columns) { @@ -263,6 +265,7 @@ public class EventRepository : Repository, IEventRepository row[secretIdColumn] = ev.SecretId.HasValue ? ev.SecretId.Value : DBNull.Value; row[serviceAccountIdColumn] = ev.ServiceAccountId.HasValue ? ev.ServiceAccountId.Value : DBNull.Value; row[projectIdColumn] = ev.ProjectId.HasValue ? ev.ProjectId.Value : DBNull.Value; + row[grantedServiceAccountIdColumn] = ev.GrantedServiceAccountId.HasValue ? ev.GrantedServiceAccountId.Value : DBNull.Value; eventsTable.Rows.Add(row); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs index 76e9b2e912..98f10394f4 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs @@ -12,9 +12,16 @@ public class EventEntityTypeConfiguration : IEntityTypeConfiguration .Property(e => e.Id) .ValueGeneratedNever(); - builder - .HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId }) - .IsClustered(false); + builder.HasKey(e => e.Id) + .IsClustered(); + + var index = builder.HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId }) + .IsClustered(false) + .HasDatabaseName("IX_Event_DateOrganizationIdUserId"); + + SqlServerIndexBuilderExtensions.IncludeProperties( + index, + e => new { e.ServiceAccountId, e.GrantedServiceAccountId }); builder.ToTable(nameof(Event)); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs index 01f3a1fe14..72dc8db386 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs @@ -30,7 +30,7 @@ public class EventReadPageByOrganizationIdServiceAccountIdQuery : IQuery (_beforeDate != null || e.Date <= _endDate) && (_beforeDate == null || e.Date < _beforeDate.Value) && e.OrganizationId == _organizationId && - e.ServiceAccountId == _serviceAccountId + (e.ServiceAccountId == _serviceAccountId || e.GrantedServiceAccountId == _serviceAccountId) orderby e.Date descending select e; return q.Skip(0).Take(_pageOptions.PageSize); diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs new file mode 100644 index 0000000000..0d1cd6a656 --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs @@ -0,0 +1,48 @@ +using Bit.Core.Models.Data; +using Bit.Core.SecretsManager.Entities; +using Event = Bit.Infrastructure.EntityFramework.Models.Event; + +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByServiceAccountQuery : IQuery +{ + private readonly ServiceAccount _serviceAccount; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + _serviceAccount = serviceAccount; + _startDate = startDate; + _endDate = endDate; + _beforeDate = null; + _pageOptions = pageOptions; + } + + public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _serviceAccount = serviceAccount; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate == null || e.Date < _beforeDate.Value) && + ( + (_serviceAccount.OrganizationId == Guid.Empty && !e.OrganizationId.HasValue) || + (_serviceAccount.OrganizationId != Guid.Empty && e.OrganizationId == _serviceAccount.OrganizationId) + ) && + e.GrantedServiceAccountId == _serviceAccount.Id + orderby e.Date descending + select e; + + return q.Take(_pageOptions.PageSize); + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index bd70e27e78..809704edb7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -283,6 +283,9 @@ public class UserRepository : Repository, IUserR var transaction = await dbContext.Database.BeginTransactionAsync(); + MigrateDefaultUserCollectionsToShared(dbContext, [user.Id]); + await dbContext.SaveChangesAsync(); + dbContext.WebAuthnCredentials.RemoveRange(dbContext.WebAuthnCredentials.Where(w => w.UserId == user.Id)); dbContext.Ciphers.RemoveRange(dbContext.Ciphers.Where(c => c.UserId == user.Id)); dbContext.Folders.RemoveRange(dbContext.Folders.Where(f => f.UserId == user.Id)); @@ -314,8 +317,8 @@ public class UserRepository : Repository, IUserR var mappedUser = Mapper.Map(user); dbContext.Users.Remove(mappedUser); - await transaction.CommitAsync(); await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); } } @@ -329,21 +332,30 @@ public class UserRepository : Repository, IUserR var targetIds = users.Select(u => u.Id).ToList(); + MigrateDefaultUserCollectionsToShared(dbContext, targetIds); + await dbContext.SaveChangesAsync(); + await dbContext.WebAuthnCredentials.Where(wa => targetIds.Contains(wa.UserId)).ExecuteDeleteAsync(); await dbContext.Ciphers.Where(c => targetIds.Contains(c.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.Folders.Where(f => targetIds.Contains(f.UserId)).ExecuteDeleteAsync(); await dbContext.AuthRequests.Where(a => targetIds.Contains(a.UserId)).ExecuteDeleteAsync(); await dbContext.Devices.Where(d => targetIds.Contains(d.UserId)).ExecuteDeleteAsync(); - var collectionUsers = from cu in dbContext.CollectionUsers - join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id - where targetIds.Contains(ou.UserId ?? default) - select cu; - dbContext.CollectionUsers.RemoveRange(collectionUsers); - var groupUsers = from gu in dbContext.GroupUsers - join ou in dbContext.OrganizationUsers on gu.OrganizationUserId equals ou.Id - where targetIds.Contains(ou.UserId ?? default) - select gu; - dbContext.GroupUsers.RemoveRange(groupUsers); + await dbContext.CollectionUsers + .Join(dbContext.OrganizationUsers, + cu => cu.OrganizationUserId, + ou => ou.Id, + (cu, ou) => new { CollectionUser = cu, OrganizationUser = ou }) + .Where((joined) => targetIds.Contains(joined.OrganizationUser.UserId ?? default)) + .Select(joined => joined.CollectionUser) + .ExecuteDeleteAsync(); + await dbContext.GroupUsers + .Join(dbContext.OrganizationUsers, + gu => gu.OrganizationUserId, + ou => ou.Id, + (gu, ou) => new { GroupUser = gu, OrganizationUser = ou }) + .Where(joined => targetIds.Contains(joined.OrganizationUser.UserId ?? default)) + .Select(joined => joined.GroupUser) + .ExecuteDeleteAsync(); await dbContext.UserProjectAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.UserServiceAccountAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.OrganizationUsers.Where(ou => targetIds.Contains(ou.UserId ?? default)).ExecuteDeleteAsync(); @@ -354,15 +366,29 @@ public class UserRepository : Repository, IUserR await dbContext.NotificationStatuses.Where(ns => targetIds.Contains(ns.UserId)).ExecuteDeleteAsync(); await dbContext.Notifications.Where(n => targetIds.Contains(n.UserId ?? default)).ExecuteDeleteAsync(); - foreach (var u in users) - { - var mappedUser = Mapper.Map(u); - dbContext.Users.Remove(mappedUser); - } + await dbContext.Users.Where(u => targetIds.Contains(u.Id)).ExecuteDeleteAsync(); - - await transaction.CommitAsync(); await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); + } + } + + private static void MigrateDefaultUserCollectionsToShared(DatabaseContext dbContext, IEnumerable userIds) + { + var defaultCollections = (from c in dbContext.Collections + join cu in dbContext.CollectionUsers on c.Id equals cu.CollectionId + join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id + join u in dbContext.Users on ou.UserId equals u.Id + where userIds.Contains(ou.UserId!.Value) + && c.Type == Core.Enums.CollectionType.DefaultUserCollection + select new { Collection = c, UserEmail = u.Email }) + .ToList(); + + foreach (var item in defaultCollections) + { + item.Collection.Type = Core.Enums.CollectionType.SharedCollection; + item.Collection.DefaultUserCollectionEmail = item.Collection.DefaultUserCollectionEmail ?? item.UserEmail; + item.Collection.RevisionDate = DateTime.UtcNow; } } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index d87f9ab97f..58ce0466c3 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -6,6 +6,8 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; +using Azure.Messaging.ServiceBus; +using Bit.Core; using Bit.Core.AdminConsole.AbilitiesCache; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; @@ -694,8 +696,23 @@ public static class ServiceCollectionExtensions { options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; options.ServerName = "Bitwarden"; - options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; options.TimestampDriftTolerance = 300000; + + if (globalSettings.Fido2?.Origins?.Any() == true) + { + options.Origins = new HashSet(globalSettings.Fido2.Origins); + } + else + { + // Default to allowing the vault domain and chromium browser extension IDs + options.Origins = new HashSet { + globalSettings.BaseServiceUri.Vault, + Constants.BrowserExtensions.ChromeId, + Constants.BrowserExtensions.EdgeId, + Constants.BrowserExtensions.OperaId + }; + } + }); } @@ -855,6 +872,11 @@ public static class ServiceCollectionExtensions configuration: listenerConfiguration, handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.EventPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) @@ -865,6 +887,11 @@ public static class ServiceCollectionExtensions configuration: listenerConfiguration, handler: provider.GetRequiredService>(), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) @@ -927,6 +954,11 @@ public static class ServiceCollectionExtensions configuration: repositoryConfiguration, handler: provider.GetRequiredService(), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = repositoryConfiguration.EventPrefetchCount, + MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) diff --git a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql index 5dc950ffff..831c9f70ee 100644 --- a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql +++ b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql @@ -18,7 +18,7 @@ BEGIN AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate) AND (@BeforeDate IS NULL OR [Date] < @BeforeDate) AND [OrganizationId] = @OrganizationId - AND [ServiceAccountId] = @ServiceAccountId + AND ([ServiceAccountId] = @ServiceAccountId OR [GrantedServiceAccountId] = @ServiceAccountId) ORDER BY [Date] DESC OFFSET 0 ROWS FETCH NEXT @PageSize ROWS ONLY diff --git a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql new file mode 100644 index 0000000000..c429a4a064 --- /dev/null +++ b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql @@ -0,0 +1,45 @@ +CREATE PROCEDURE [dbo].[Event_ReadPageByServiceAccountId] + @GrantedServiceAccountId UNIQUEIDENTIFIER, + @StartDate DATETIME2(7), + @EndDate DATETIME2(7), + @BeforeDate DATETIME2(7), + @PageSize INT +AS +BEGIN + SET NOCOUNT ON + + SELECT + e.Id, + e.Date, + e.Type, + e.UserId, + e.OrganizationId, + e.InstallationId, + e.ProviderId, + e.CipherId, + e.CollectionId, + e.PolicyId, + e.GroupId, + e.OrganizationUserId, + e.ProviderUserId, + e.ProviderOrganizationId, + e.DeviceType, + e.IpAddress, + e.ActingUserId, + e.SystemUser, + e.DomainName, + e.SecretId, + e.ServiceAccountId, + e.ProjectId, + e.GrantedServiceAccountId + FROM + [dbo].[EventView] e + WHERE + [Date] >= @StartDate + AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate) + AND (@BeforeDate IS NULL OR [Date] < @BeforeDate) + AND [GrantedServiceAccountId] = @GrantedServiceAccountId + ORDER BY [Date] DESC + OFFSET 0 ROWS + FETCH NEXT @PageSize ROWS ONLY +END diff --git a/src/Sql/dbo/Stored Procedures/Event_Create.sql b/src/Sql/dbo/Stored Procedures/Event_Create.sql index 89971bd56f..0466bc1a69 100644 --- a/src/Sql/dbo/Stored Procedures/Event_Create.sql +++ b/src/Sql/dbo/Stored Procedures/Event_Create.sql @@ -20,7 +20,8 @@ @DomainName VARCHAR(256), @SecretId UNIQUEIDENTIFIER = null, @ServiceAccountId UNIQUEIDENTIFIER = null, - @ProjectId UNIQUEIDENTIFIER = null + @ProjectId UNIQUEIDENTIFIER = null, + @GrantedServiceAccountId UNIQUEIDENTIFIER = null AS BEGIN SET NOCOUNT ON @@ -48,7 +49,8 @@ BEGIN [DomainName], [SecretId], [ServiceAccountId], - [ProjectId] + [ProjectId], + [GrantedServiceAccountId] ) VALUES ( @@ -73,6 +75,7 @@ BEGIN @DomainName, @SecretId, @ServiceAccountId, - @ProjectId + @ProjectId, + @GrantedServiceAccountId ) END diff --git a/src/Sql/dbo/Stored Procedures/User_DeleteById.sql b/src/Sql/dbo/Stored Procedures/User_DeleteById.sql index 0608982e37..6377166e17 100644 --- a/src/Sql/dbo/Stored Procedures/User_DeleteById.sql +++ b/src/Sql/dbo/Stored Procedures/User_DeleteById.sql @@ -52,6 +52,16 @@ BEGIN WHERE [UserId] = @Id + -- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records + DECLARE @OrgUserIds [dbo].[GuidIdArray] + INSERT INTO @OrgUserIds (Id) + SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] = @Id + + IF EXISTS (SELECT 1 FROM @OrgUserIds) + BEGIN + EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds + END + -- Delete collection users DELETE CU diff --git a/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql b/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql index 97ab955f83..cdf3dd7d3a 100644 --- a/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql +++ b/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql @@ -66,6 +66,16 @@ BEGIN WHERE [UserId] IN (SELECT * FROM @ParsedIds) + -- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records + DECLARE @OrgUserIds [dbo].[GuidIdArray] + INSERT INTO @OrgUserIds (Id) + SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] IN (SELECT * FROM @ParsedIds) + + IF EXISTS (SELECT 1 FROM @OrgUserIds) + BEGIN + EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds + END + -- Delete collection users DELETE CU diff --git a/src/Sql/dbo/Tables/Event.sql b/src/Sql/dbo/Tables/Event.sql index 6dfb4392a0..ea0dda5661 100644 --- a/src/Sql/dbo/Tables/Event.sql +++ b/src/Sql/dbo/Tables/Event.sql @@ -21,11 +21,12 @@ [SecretId] UNIQUEIDENTIFIER NULL, [ServiceAccountId] UNIQUEIDENTIFIER NULL, [ProjectId] UNIQUEIDENTIFIER NULL, + [GrantedServiceAccountId] UNIQUEIDENTIFIER NULL, CONSTRAINT [PK_Event] PRIMARY KEY CLUSTERED ([Id] ASC) ); GO CREATE NONCLUSTERED INDEX [IX_Event_DateOrganizationIdUserId] - ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC); + ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC) INCLUDE ([ServiceAccountId], [GrantedServiceAccountId]); diff --git a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs index 9bbc8a77c0..376fb01493 100644 --- a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs @@ -1,12 +1,18 @@ -using Bit.Api.AdminConsole.Controllers; +#nullable enable + +using Bit.Api.AdminConsole.Controllers; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Routing; +using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; @@ -16,98 +22,312 @@ namespace Bit.Api.Test.AdminConsole.Controllers; [SutProviderCustomize] public class SlackIntegrationControllerTests { + private const string _slackToken = "xoxb-test-token"; + private const string _validSlackCode = "A_test_code"; + [Theory, BitAutoData] - public async Task CreateAsync_AllParamsProvided_Succeeds(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + OrganizationIntegration integration) { - var token = "xoxb-test-token"; + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) - .Returns(token); + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(callInfo => callInfo.Arg()); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, "A_test_code"); + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + var requestAction = await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()); await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); + .UpsertAsync(Arg.Any()); Assert.IsType(requestAction); } [Theory, BitAutoData] - public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, string.Empty)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(string.Empty, state.ToString())); } [Theory, BitAutoData] - public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) .Returns(string.Empty); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code")); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); } [Theory, BitAutoData] - public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_StateEmpty_ThrowsNotFound( + SutProvider sutProvider) { - var token = "xoxb-test-token"; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) - .Returns(token); + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code")); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, String.Empty)); } [Theory, BitAutoData] - public async Task RedirectAsync_Success(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_StateExpired_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) { - var expectedUrl = $"https://localhost/{organizationId}"; + var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc)); + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + timeProvider.Advance(TimeSpan.FromMinutes(30)); + + sutProvider.SetDependency(timeProvider); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonexistentIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasWrongOgranizationHash_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration wrongOrgIntegration) + { + wrongOrgIntegration.Id = integration.Id; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(expectedUrl); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(wrongOrgIntegration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonEmptyIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Slack; + integration.Configuration = "{}"; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonSlackIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Hec; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_Success( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(integration.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + var requestAction = await sutProvider.Sut.RedirectAsync(integration.OrganizationId); + + Assert.IsType(requestAction); + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Any()); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithNullConfig_Success( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Slack; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); var requestAction = await sutProvider.Sut.RedirectAsync(organizationId); - var redirectResult = Assert.IsType(requestAction); - Assert.Equal(expectedUrl, redirectResult.Url); + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + Assert.IsType(requestAction); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); } [Theory, BitAutoData] - public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) { + integration.OrganizationId = organizationId; + integration.Configuration = "{}"; + integration.Type = IntegrationType.Slack; + var expectedUrl = "https://localhost/"; + sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(string.Empty); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .Returns(expectedUrl); sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(string.Empty); await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); } @@ -116,14 +336,9 @@ public class SlackIntegrationControllerTests public async Task RedirectAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(string.Empty); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); } diff --git a/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs new file mode 100644 index 0000000000..babdf3894d --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs @@ -0,0 +1,117 @@ +#nullable enable + +using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response.Organizations; + +public class OrganizationIntegrationResponseModelTests +{ + [Theory, BitAutoData] + public void Status_CloudBillingSync_AlwaysNotApplicable(OrganizationIntegration oi) + { + oi.Type = IntegrationType.CloudBillingSync; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + } + + [Theory, BitAutoData] + public void Status_Scim_AlwaysNotApplicable(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Scim; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + } + + [Theory, BitAutoData] + public void Status_Slack_NullConfig_ReturnsInitiated(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Slack; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Initiated, model.Status); + } + + [Theory, BitAutoData] + public void Status_Slack_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Slack; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Webhook_AlwaysCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Webhook; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Hec_NullConfig_ReturnsInvalid(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Hec; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status); + } + + [Theory, BitAutoData] + public void Status_Hec_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Hec; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Datadog_NullConfig_ReturnsInvalid(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Datadog; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status); + } + + [Theory, BitAutoData] + public void Status_Datadog_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Datadog; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } +} diff --git a/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs b/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs new file mode 100644 index 0000000000..8348ba885d --- /dev/null +++ b/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs @@ -0,0 +1,313 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Auth.Models.Request.Organizations; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Services; +using Bit.Core.Sso; +using Microsoft.Extensions.Localization; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Auth.Models.Request; + +public class OrganizationSsoRequestModelTests +{ + [Fact] + public void ToSsoConfig_WithOrganizationId_CreatesNewSsoConfig() + { + // Arrange + var organizationId = Guid.NewGuid(); + var model = new OrganizationSsoRequestModel + { + Enabled = true, + Identifier = "test-identifier", + Data = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = "https://example.com", + ClientId = "test-client", + ClientSecret = "test-secret" + } + }; + + // Act + var result = model.ToSsoConfig(organizationId); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.OrganizationId); + Assert.True(result.Enabled); + } + + [Fact] + public void ToSsoConfig_WithExistingConfig_UpdatesExistingConfig() + { + // Arrange + var organizationId = Guid.NewGuid(); + var existingConfig = new SsoConfig + { + Id = 1, + OrganizationId = organizationId, + Enabled = false + }; + + var model = new OrganizationSsoRequestModel + { + Enabled = true, + Identifier = "updated-identifier", + Data = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = "test-entity", + IdpSingleSignOnServiceUrl = "https://sso.example.com" + } + }; + + // Act + var result = model.ToSsoConfig(existingConfig); + + // Assert + Assert.Same(existingConfig, result); + Assert.Equal(organizationId, result.OrganizationId); + Assert.True(result.Enabled); + } +} + +public class SsoConfigurationDataRequestTests +{ + private readonly TestI18nService _i18nService; + private readonly ValidationContext _validationContext; + + public SsoConfigurationDataRequestTests() + { + _i18nService = new TestI18nService(); + var serviceProvider = Substitute.For(); + serviceProvider.GetService(typeof(II18nService)).Returns(_i18nService); + _validationContext = new ValidationContext(new object(), serviceProvider, null); + } + + [Fact] + public void ToConfigurationData_MapsProperties() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + MemberDecryptionType = MemberDecryptionType.KeyConnector, + Authority = "https://authority.example.com", + ClientId = "test-client-id", + ClientSecret = "test-client-secret", + IdpX509PublicCert = "-----BEGIN CERTIFICATE-----\nMIIC...test\n-----END CERTIFICATE-----", + SpOutboundSigningAlgorithm = null // Test default + }; + + // Act + var result = model.ToConfigurationData(); + + // Assert + Assert.Equal(SsoType.OpenIdConnect, result.ConfigType); + Assert.Equal(MemberDecryptionType.KeyConnector, result.MemberDecryptionType); + Assert.Equal("https://authority.example.com", result.Authority); + Assert.Equal("test-client-id", result.ClientId); + Assert.Equal("test-client-secret", result.ClientSecret); + Assert.Equal("MIIC...test", result.IdpX509PublicCert); // PEM headers stripped + Assert.Equal(SamlSigningAlgorithms.Sha256, result.SpOutboundSigningAlgorithm); // Default applied + Assert.Null(result.IdpArtifactResolutionServiceUrl); // Always null + } + + [Fact] + public void KeyConnectorEnabled_Setter_UpdatesMemberDecryptionType() + { + // Arrange + var model = new SsoConfigurationDataRequest(); + + // Act & Assert +#pragma warning disable CS0618 // Type or member is obsolete + model.KeyConnectorEnabled = true; + Assert.Equal(MemberDecryptionType.KeyConnector, model.MemberDecryptionType); + + model.KeyConnectorEnabled = false; + Assert.Equal(MemberDecryptionType.MasterPassword, model.MemberDecryptionType); +#pragma warning restore CS0618 // Type or member is obsolete + } + + // Validation Tests + [Fact] + public void Validate_OpenIdConnect_ValidData_NoErrors() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = "https://example.com", + ClientId = "test-client", + ClientSecret = "test-secret" + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Empty(results); + } + + [Theory] + [InlineData("", "test-client", "test-secret", "AuthorityValidationError")] + [InlineData("https://example.com", "", "test-secret", "ClientIdValidationError")] + [InlineData("https://example.com", "test-client", "", "ClientSecretValidationError")] + public void Validate_OpenIdConnect_MissingRequiredFields_ReturnsErrors(string authority, string clientId, string clientSecret, string expectedError) + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = authority, + ClientId = clientId, + ClientSecret = clientSecret + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Single(results); + Assert.Equal(expectedError, results[0].ErrorMessage); + } + + [Fact] + public void Validate_Saml2_ValidData_NoErrors() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = "https://idp.example.com", + IdpSingleSignOnServiceUrl = "https://sso.example.com", + IdpSingleLogoutServiceUrl = "https://logout.example.com" + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Empty(results); + } + + [Theory] + [InlineData("", "https://sso.example.com", "IdpEntityIdValidationError")] + [InlineData("not-a-valid-uri", "", "IdpSingleSignOnServiceUrlValidationError")] + public void Validate_Saml2_MissingRequiredFields_ReturnsErrors(string entityId, string signOnUrl, string expectedError) + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = entityId, + IdpSingleSignOnServiceUrl = signOnUrl + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Contains(results, r => r.ErrorMessage == expectedError); + } + + [Theory] + [InlineData("not-a-url")] + [InlineData("ftp://example.com")] + [InlineData("https://example.com