diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 65780bdb63..597085d97d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,6 +36,7 @@ util/Setup/** @bitwarden/dept-bre @bitwarden/team-platform-dev # UIF src/Core/MailTemplates/Mjml @bitwarden/team-ui-foundation # Teams are expected to own sub-directories of this project +src/Core/MailTemplates/Mjml/.mjmlconfig # This change allows teams to add components within their own subdirectories without requiring a code review from UIF. # Auth team **/Auth @bitwarden/team-auth-dev diff --git a/.github/ISSUE_TEMPLATE/bw-unified.yml b/.github/ISSUE_TEMPLATE/bw-lite.yml similarity index 88% rename from .github/ISSUE_TEMPLATE/bw-unified.yml rename to .github/ISSUE_TEMPLATE/bw-lite.yml index 240b1faa72..0c43fa5835 100644 --- a/.github/ISSUE_TEMPLATE/bw-unified.yml +++ b/.github/ISSUE_TEMPLATE/bw-lite.yml @@ -1,6 +1,6 @@ -name: Bitwarden Unified Deployment Bug Report +name: Bitwarden lite Deployment Bug Report description: File a bug report -labels: [bug, bw-unified-deploy] +labels: [bug, bw-lite-deploy] body: - type: markdown attributes: @@ -70,15 +70,6 @@ body: mariadb:10 # Postgres Example postgres:14 - - type: textarea - id: epic-label - attributes: - label: Issue-Link - description: Link to our pinned issue, tracking all Bitwarden Unified - value: | - https://github.com/bitwarden/server/issues/2480 - validations: - required: true - type: checkboxes id: issue-tracking-info attributes: diff --git a/.github/renovate.json5 b/.github/renovate.json5 index bc377ed46c..074b4dde2b 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -42,8 +42,9 @@ dependencyDashboardApproval: false, }, { - matchSourceUrls: ["https://github.com/bitwarden/sdk-internal"], + matchPackageNames: ["https://github.com/bitwarden/sdk-internal.git"], groupName: "sdk-internal", + dependencyDashboardApproval: true }, { matchManagers: ["dockerfile", "docker-compose"], @@ -63,7 +64,6 @@ }, { matchPackageNames: [ - "Azure.Extensions.AspNetCore.DataProtection.Blobs", "DuoUniversal", "Fido2.AspNet", "Duende.IdentityServer", @@ -90,11 +90,7 @@ "Microsoft.AspNetCore.Mvc.Testing", "Newtonsoft.Json", "NSubstitute", - "Sentry.Serilog", - "Serilog.AspNetCore", - "Serilog.Extensions.Logging", "Serilog.Extensions.Logging.File", - "Serilog.Sinks.SyslogMessages", "Stripe.net", "Swashbuckle.AspNetCore", "Swashbuckle.AspNetCore.SwaggerGen", @@ -141,6 +137,7 @@ "AspNetCoreRateLimit", "AspNetCoreRateLimit.Redis", "Azure.Data.Tables", + "Azure.Extensions.AspNetCore.DataProtection.Blobs", "Azure.Messaging.EventGrid", "Azure.Messaging.ServiceBus", "Azure.Storage.Blobs", diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2d92c68b93..9b457b9d56 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,7 +22,7 @@ env: jobs: lint: name: Lint - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 steps: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -38,7 +38,7 @@ jobs: build-artifacts: name: Build Docker images - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - lint outputs: @@ -49,7 +49,6 @@ jobs: timeout-minutes: 45 strategy: fail-fast: false - max-parallel: 5 matrix: include: - project_name: Admin @@ -186,13 +185,6 @@ jobs: - name: Log in to ACR - production subscription run: az acr login -n bitwardenprod - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat - uses: bitwarden/gh-actions/get-keyvault-secrets@main - with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" - ########## Generate image tag and build Docker image ########## - name: Generate Docker image tag id: tag @@ -251,8 +243,6 @@ jobs: linux/arm64 push: true tags: ${{ steps.image-tags.outputs.tags }} - secrets: | - "GH_PAT=${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}" - name: Install Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' @@ -281,7 +271,7 @@ jobs: output-format: sarif - name: Upload Grype results to GitHub - uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8 + uses: github/codeql-action/upload-sarif@e12f0178983d466f2f6028f5cc7a6d786fd97f4b # v4.31.4 with: sarif_file: ${{ steps.container-scan.outputs.sarif }} sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }} @@ -292,7 +282,7 @@ jobs: upload: name: Upload - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: build-artifacts permissions: id-token: write @@ -410,7 +400,7 @@ jobs: build-mssqlmigratorutility: name: Build MSSQL migrator utility - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - lint defaults: @@ -467,7 +457,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - build-artifacts permissions: @@ -480,25 +470,34 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main - - name: Trigger self-host build + - name: Generate GH App token + uses: actions/create-github-app-token@67018539274d69449ef7c02e8e71183d1719ab42 # v2.1.4 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + owner: ${{ github.repository_owner }} + repositories: self-host + + - name: Trigger Bitwarden lite build uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', repo: 'self-host', - workflow_id: 'build-unified.yml', + workflow_id: 'build-bitwarden-lite.yml', ref: 'main', inputs: { server_branch: process.env.GITHUB_REF @@ -521,20 +520,29 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main + - name: Generate GH App token + uses: actions/create-github-app-token@67018539274d69449ef7c02e8e71183d1719ab42 # v2.1.4 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + owner: ${{ github.repository_owner }} + repositories: devops + - name: Trigger k8s deploy uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', diff --git a/.github/workflows/repository-management.yml b/.github/workflows/repository-management.yml index 92452102cf..74823c34b5 100644 --- a/.github/workflows/repository-management.yml +++ b/.github/workflows/repository-management.yml @@ -22,9 +22,7 @@ on: required: false type: string -permissions: - pull-requests: write - contents: write +permissions: {} jobs: setup: @@ -32,6 +30,7 @@ jobs: runs-on: ubuntu-24.04 outputs: branch: ${{ steps.set-branch.outputs.branch }} + permissions: {} steps: - name: Set branch id: set-branch @@ -89,6 +88,7 @@ jobs: with: app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + permission-contents: write - name: Check out branch uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -212,6 +212,7 @@ jobs: with: app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + permission-contents: write - name: Check out target ref uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -240,10 +241,5 @@ jobs: move_edd_db_scripts: name: Move EDD database scripts needs: cut_branch - permissions: - actions: read - contents: write - id-token: write - pull-requests: write + permissions: {} uses: ./.github/workflows/_move_edd_db_scripts.yml - secrets: inherit diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index fb1c18b158..b0d0c076a1 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -62,7 +62,7 @@ jobs: docker compose --profile mssql --profile postgres --profile mysql up -d shell: pwsh - - name: Add MariaDB for unified + - name: Add MariaDB for Bitwarden lite # Use a different port than MySQL run: | docker run --detach --name mariadb --env MARIADB_ROOT_PASSWORD=mariadb-password -p 4306:3306 mariadb:10 @@ -133,7 +133,7 @@ jobs: # Default Sqlite BW_TEST_DATABASES__3__TYPE: "Sqlite" BW_TEST_DATABASES__3__CONNECTIONSTRING: "Data Source=${{ runner.temp }}/test.db" - # Unified MariaDB + # Bitwarden lite MariaDB BW_TEST_DATABASES__4__TYPE: "MySql" BW_TEST_DATABASES__4__CONNECTIONSTRING: "server=localhost;port=4306;uid=root;pwd=mariadb-password;database=vault_dev;Allow User Variables=true" run: dotnet test --logger "trx;LogFileName=infrastructure-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" @@ -262,3 +262,26 @@ jobs: working-directory: "dev" run: docker compose down shell: pwsh + + validate-migration-naming: + name: Validate new migration naming and order + runs-on: ubuntu-22.04 + + steps: + - name: Check out repo + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Validate new migrations for pull request + if: github.event_name == 'pull_request' + run: | + git fetch origin main:main + pwsh dev/verify_migrations.ps1 -BaseRef main + shell: pwsh + + - name: Validate new migrations for push + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + run: pwsh dev/verify_migrations.ps1 -BaseRef HEAD~1 + shell: pwsh diff --git a/.gitignore b/.gitignore index 60fc894285..db8cb50f84 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,7 @@ bitwarden_license/src/Sso/Sso.zip /identity.json /api.json /api.public.json +.serena/ # Serena .serena/ diff --git a/Directory.Build.props b/Directory.Build.props index 4511202024..db3ccf40f5 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.11.0 + 2025.12.2 Bit.$(MSBuildProjectName) enable @@ -16,7 +16,7 @@ - 17.8.0 + 18.0.1 2.6.6 diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 994b305349..12d370395c 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -113,7 +113,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv await _providerBillingService.CreateCustomerForClientOrganization(provider, organization); } - var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + var customer = await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Description = string.Empty, Email = organization.BillingEmail, @@ -138,7 +138,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await _stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; organization.Status = OrganizationStatusType.Created; @@ -148,27 +148,26 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv } else if (organization.IsStripeEnabled()) { - var subscription = await _stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions + var subscription = await _stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer"] }); - if (subscription.Status is StripeConstants.SubscriptionStatus.Canceled or StripeConstants.SubscriptionStatus.IncompleteExpired) { return; } - await _stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail }); if (subscription.Customer.Discount?.Coupon != null) { - await _stripeAdapter.CustomerDeleteDiscountAsync(subscription.CustomerId); + await _stripeAdapter.DeleteCustomerDiscountAsync(subscription.CustomerId); } - await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, DaysUntilDue = 30, diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index 89ef251fd6..4e8a23cf4e 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -9,12 +9,16 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -59,6 +63,7 @@ public class ProviderService : IProviderService private readonly IProviderBillingService _providerBillingService; private readonly IPricingClient _pricingClient; private readonly IProviderClientOrganizationSignUpCommand _providerClientOrganizationSignUpCommand; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, @@ -68,7 +73,8 @@ public class ProviderService : IProviderService ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService, IDataProtectorTokenFactory providerDeleteTokenDataFactory, IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient, - IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand) + IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand, + IPolicyRequirementQuery policyRequirementQuery) { _providerRepository = providerRepository; _providerUserRepository = providerUserRepository; @@ -89,6 +95,7 @@ public class ProviderService : IProviderService _providerBillingService = providerBillingService; _pricingClient = pricingClient; _providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand; + _policyRequirementQuery = policyRequirementQuery; } public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) @@ -116,6 +123,18 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid owner."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(ownerUserId); + + if (organizationAutoConfirmPolicyRequirement + .CannotCreateProvider()) + { + throw new BadRequestException(new UserCannotJoinProvider().Message); + } + } + var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); @@ -248,6 +267,18 @@ public class ProviderService : IProviderService throw new BadRequestException("User email does not match invite."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(user.Id); + + if (organizationAutoConfirmPolicyRequirement + .CannotJoinProvider()) + { + throw new BadRequestException(new UserCannotJoinProvider().Message); + } + } + providerUser.Status = ProviderUserStatusType.Accepted; providerUser.UserId = user.Id; providerUser.Email = null; @@ -293,6 +324,19 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid user."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(user.Id); + + if (organizationAutoConfirmPolicyRequirement + .CannotJoinProvider()) + { + result.Add(Tuple.Create(providerUser, new UserCannotJoinProvider().Message)); + continue; + } + } + providerUser.Status = ProviderUserStatusType.Confirmed; providerUser.Key = keys[providerUser.Id]; providerUser.Email = null; @@ -427,7 +471,7 @@ public class ProviderService : IProviderService if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) { - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = provider.BillingEmail }); @@ -487,7 +531,7 @@ public class ProviderService : IProviderService private async Task GetSubscriptionItemAsync(string subscriptionId, string oldPlanId) { - var subscriptionDetails = await _stripeAdapter.SubscriptionGetAsync(subscriptionId); + var subscriptionDetails = await _stripeAdapter.GetSubscriptionAsync(subscriptionId); return subscriptionDetails.Items.Data.FirstOrDefault(item => item.Price.Id == oldPlanId); } @@ -497,7 +541,7 @@ public class ProviderService : IProviderService { if (subscriptionItem.Price.Id != extractedPlanType) { - await _stripeAdapter.SubscriptionUpdateAsync(subscriptionItem.Subscription, + await _stripeAdapter.UpdateSubscriptionAsync(subscriptionItem.Subscription, new Stripe.SubscriptionUpdateOptions { Items = new List diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs index cc77797307..e140a13841 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -76,8 +75,8 @@ public class GetProviderWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs index 8e8a89ae58..ce2f7a941f 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs @@ -101,7 +101,7 @@ public class BusinessUnitConverter( providerUser.Status = ProviderUserStatusType.Confirmed; // Stripe requires that we clear all the custom fields from the invoice settings if we want to replace them. - await stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -116,7 +116,7 @@ public class BusinessUnitConverter( ["convertedFrom"] = organization.Id.ToString() }; - var updateCustomer = stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + var updateCustomer = stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -148,7 +148,7 @@ public class BusinessUnitConverter( // Replace the existing password manager price with the new business unit price. var updateSubscription = - stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index e352297f1e..7042a531d0 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -61,11 +61,11 @@ public class ProviderBillingService( Organization organization, string key) { - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); var subscription = - await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, + await stripeAdapter.CancelSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionCancelOptions { CancellationDetails = new SubscriptionCancellationDetailsOptions @@ -83,7 +83,7 @@ public class ProviderBillingService( if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft) { - await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, + await stripeAdapter.FinalizeInvoiceAsync(subscription.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = true }); } @@ -138,7 +138,7 @@ public class ProviderBillingService( if (clientCustomer.Balance != 0) { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, + await stripeAdapter.CreateCustomerBalanceTransactionAsync(provider.GatewayCustomerId, new CustomerBalanceTransactionCreateOptions { Amount = clientCustomer.Balance, @@ -187,7 +187,7 @@ public class ProviderBillingService( ] }; - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, updateOptions); // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan @@ -275,7 +275,7 @@ public class ProviderBillingService( customerCreateOptions.TaxExempt = TaxExempt.Reverse; } - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); organization.GatewayCustomerId = customer.Id; @@ -525,7 +525,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) @@ -558,7 +558,7 @@ public class ProviderBillingService( try { - return await stripeAdapter.CustomerCreateAsync(options); + return await stripeAdapter.CreateCustomerAsync(options); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid) { @@ -580,7 +580,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); - await stripeAdapter.SetupIntentCancel(setupIntentId, + await stripeAdapter.CancelSetupIntentAsync(setupIntentId, new SetupIntentCancelOptions { CancellationReason = "abandoned" }); await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); break; @@ -638,7 +638,7 @@ public class ProviderBillingService( var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntent = !string.IsNullOrEmpty(setupIntentId) - ? await stripeAdapter.SetupIntentGet(setupIntentId, + ? await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }) : null; @@ -673,7 +673,7 @@ public class ProviderBillingService( try { - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (subscription is { @@ -708,7 +708,7 @@ public class ProviderBillingService( subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource), subscriberService.UpdateTaxInformation(provider, taxInformation)); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically }); } @@ -791,11 +791,49 @@ public class ProviderBillingService( if (subscriptionItemOptionsList.Count > 0) { - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); } } + public async Task UpdateProviderNameAndEmail(Provider provider) + { + if (string.IsNullOrWhiteSpace(provider.GatewayCustomerId)) + { + logger.LogWarning( + "Provider ({ProviderId}) has no Stripe customer to update", + provider.Id); + return; + } + + var newDisplayName = provider.DisplayName(); + + // Provider.DisplayName() can return null - handle gracefully + if (string.IsNullOrWhiteSpace(newDisplayName)) + { + logger.LogWarning( + "Provider ({ProviderId}) has no name to update in Stripe", + provider.Id); + return; + } + + await stripeAdapter.UpdateCustomerAsync(provider.GatewayCustomerId, + new CustomerUpdateOptions + { + Email = provider.BillingEmail, + Description = newDisplayName, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = newDisplayName + }] + }, + }); + } + private Func CurrySeatScalingUpdate( Provider provider, ProviderPlan providerPlan, @@ -807,7 +845,7 @@ public class ProviderBillingService( var item = subscription.Items.First(item => item.Price.Id == priceId); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs new file mode 100644 index 0000000000..22421f9921 --- /dev/null +++ b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs @@ -0,0 +1,94 @@ +using AutoMapper; +using Bit.Core.SecretsManager.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.SecretsManager.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Commercial.Infrastructure.EntityFramework.SecretsManager.Repositories; + +public class SecretVersionRepository : Repository, ISecretVersionRepository +{ + public SecretVersionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, db => db.SecretVersion) + { } + + public override async Task GetByIdAsync(Guid id) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var secretVersion = await dbContext.SecretVersion + .Where(sv => sv.Id == id) + .FirstOrDefaultAsync(); + return Mapper.Map(secretVersion); + } + + public async Task> GetManyBySecretIdAsync(Guid secretId) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var secretVersions = await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretId) + .OrderByDescending(sv => sv.VersionDate) + .ToListAsync(); + return Mapper.Map>(secretVersions); + } + + public async Task> GetManyByIdsAsync(IEnumerable ids) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var versionIds = ids.ToList(); + var secretVersions = await dbContext.SecretVersion + .Where(sv => versionIds.Contains(sv.Id)) + .OrderByDescending(sv => sv.VersionDate) + .ToListAsync(); + return Mapper.Map>(secretVersions); + } + + public override async Task CreateAsync(Core.SecretsManager.Entities.SecretVersion secretVersion) + { + const int maxVersionsToKeep = 10; + + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + + await using var transaction = await dbContext.Database.BeginTransactionAsync(); + + // Get the IDs of the most recent (maxVersionsToKeep - 1) versions to keep + var versionsToKeepIds = await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretVersion.SecretId) + .OrderByDescending(sv => sv.VersionDate) + .Take(maxVersionsToKeep - 1) + .Select(sv => sv.Id) + .ToListAsync(); + + // Delete all versions for this secret that are not in the "keep" list + if (versionsToKeepIds.Any()) + { + await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretVersion.SecretId && !versionsToKeepIds.Contains(sv.Id)) + .ExecuteDeleteAsync(); + } + + secretVersion.SetNewId(); + var entity = Mapper.Map(secretVersion); + + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); + + return secretVersion; + } + + public async Task DeleteManyByIdAsync(IEnumerable ids) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + + var secretVersionIds = ids.ToList(); + await dbContext.SecretVersion + .Where(sv => secretVersionIds.Contains(sv.Id)) + .ExecuteDeleteAsync(); + } +} diff --git a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs index d6c8848079..ac52c40ba6 100644 --- a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs @@ -10,6 +10,7 @@ public static class SecretsManagerEfServiceCollectionExtensions { services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); } diff --git a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs index e3c290c85f..88d6858cb8 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs @@ -61,17 +61,15 @@ public class GroupsController : Controller [HttpGet("")] public async Task Get( Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) + [FromQuery] GetGroupsQueryParamModel model) { - var groupsListQueryResult = await _getGroupsListQuery.GetGroupsListAsync(organizationId, filter, count, startIndex); + var groupsListQueryResult = await _getGroupsListQuery.GetGroupsListAsync(organizationId, model); var scimListResponseModel = new ScimListResponseModel { Resources = groupsListQueryResult.groupList.Select(g => new ScimGroupResponseModel(g)).ToList(), - ItemsPerPage = count.GetValueOrDefault(groupsListQueryResult.groupList.Count()), + ItemsPerPage = model.Count, TotalResults = groupsListQueryResult.totalResults, - StartIndex = startIndex.GetValueOrDefault(1), + StartIndex = model.StartIndex, }; return Ok(scimListResponseModel); } diff --git a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs index afbfa50bb4..91d79542b5 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; diff --git a/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs b/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs index cc6546700b..f0a561a29f 100644 --- a/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs +++ b/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Scim.Groups.Interfaces; +using Bit.Scim.Models; namespace Bit.Scim.Groups; @@ -16,10 +17,16 @@ public class GetGroupsListQuery : IGetGroupsListQuery _groupRepository = groupRepository; } - public async Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, string filter, int? count, int? startIndex) + public async Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync( + Guid organizationId, GetGroupsQueryParamModel groupQueryParams) { string nameFilter = null; string externalIdFilter = null; + + int count = groupQueryParams.Count; + int startIndex = groupQueryParams.StartIndex; + string filter = groupQueryParams.Filter; + if (!string.IsNullOrWhiteSpace(filter)) { if (filter.StartsWith("displayName eq ")) @@ -53,11 +60,11 @@ public class GetGroupsListQuery : IGetGroupsListQuery } totalResults = groupList.Count; } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + else if (string.IsNullOrWhiteSpace(filter)) { groupList = groups.OrderBy(g => g.Name) - .Skip(startIndex.Value - 1) - .Take(count.Value) + .Skip(startIndex - 1) + .Take(count) .ToList(); totalResults = groups.Count; } diff --git a/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs b/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs index 07ff044701..4b4ba09e1d 100644 --- a/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs +++ b/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Scim.Models; namespace Bit.Scim.Groups.Interfaces; public interface IGetGroupsListQuery { - Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, string filter, int? count, int? startIndex); + Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, GetGroupsQueryParamModel model); } diff --git a/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs b/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs new file mode 100644 index 0000000000..5389727917 --- /dev/null +++ b/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs @@ -0,0 +1,14 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Scim.Models; + +public class GetGroupsQueryParamModel +{ + public string Filter { get; init; } = string.Empty; + + [Range(1, int.MaxValue)] + public int Count { get; init; } = 50; + + [Range(1, int.MaxValue)] + public int StartIndex { get; init; } = 1; +} diff --git a/bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs b/bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs similarity index 91% rename from bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs rename to bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs index 27d7b6d9a1..cd50dbca61 100644 --- a/bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs +++ b/bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs @@ -1,5 +1,7 @@ using System.ComponentModel.DataAnnotations; +namespace Bit.Scim.Models; + public class GetUsersQueryParamModel { public string Filter { get; init; } = string.Empty; diff --git a/bitwarden_license/src/Scim/Program.cs b/bitwarden_license/src/Scim/Program.cs index 92f12f59dd..02f2e00d32 100644 --- a/bitwarden_license/src/Scim/Program.cs +++ b/bitwarden_license/src/Scim/Program.cs @@ -11,21 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.ScimSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/bitwarden_license/src/Scim/Startup.cs b/bitwarden_license/src/Scim/Startup.cs index edbbf34aea..2a84faa8dd 100644 --- a/bitwarden_license/src/Scim/Startup.cs +++ b/bitwarden_license/src/Scim/Startup.cs @@ -94,11 +94,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs b/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs index a734635ebf..c7085eb6b9 100644 --- a/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs +++ b/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs @@ -3,6 +3,7 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Scim.Models; using Bit.Scim.Users.Interfaces; namespace Bit.Scim.Users; diff --git a/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs b/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs index f584cb8e7b..04133c89eb 100644 --- a/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs +++ b/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs @@ -1,4 +1,5 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Scim.Models; namespace Bit.Scim.Users.Interfaces; diff --git a/bitwarden_license/src/Scim/Users/PatchUserCommand.cs b/bitwarden_license/src/Scim/Users/PatchUserCommand.cs index 6c983611ee..474557a9cb 100644 --- a/bitwarden_license/src/Scim/Users/PatchUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PatchUserCommand.cs @@ -1,5 +1,5 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 5b4a0c29cd..696d600348 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.E using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.Utilities.Commands; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -24,7 +25,7 @@ public class PostUserCommand( IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, - IPaymentService paymentService, + IStripePaymentService paymentService, IScimContext scimContext, IFeatureService featureService, IInviteOrganizationUsersCommand inviteOrganizationUsersCommand, diff --git a/bitwarden_license/src/Scim/appsettings.Development.json b/bitwarden_license/src/Scim/appsettings.Development.json index 32253a93c1..496d0c075f 100644 --- a/bitwarden_license/src/Scim/appsettings.Development.json +++ b/bitwarden_license/src/Scim/appsettings.Development.json @@ -30,6 +30,7 @@ }, "storage": { "connectionString": "UseDevelopmentStorage=true" - } + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/bitwarden_license/src/Scim/appsettings.json b/bitwarden_license/src/Scim/appsettings.json index dcdfeb3ede..18b7a7ca7b 100644 --- a/bitwarden_license/src/Scim/appsettings.json +++ b/bitwarden_license/src/Scim/appsettings.json @@ -30,9 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index a0842daa34..7141f8429d 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -201,12 +201,15 @@ public class AccountController : Controller returnUrl, state = context.Parameters["state"], userIdentifier = context.Parameters["session_state"], + ssoToken }); } [HttpGet] - public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) + public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier, string ssoToken) { + ValidateSchemeAgainstSsoToken(scheme, ssoToken); + if (string.IsNullOrEmpty(returnUrl)) { returnUrl = "~/"; @@ -235,6 +238,31 @@ public class AccountController : Controller return Challenge(props, scheme); } + /// + /// Validates the scheme (organization ID) against the organization ID found in the ssoToken. + /// + /// The authentication scheme (organization ID) to validate. + /// The SSO token to validate against. + /// Thrown if the scheme (organization ID) does not match the organization ID found in the ssoToken. + private void ValidateSchemeAgainstSsoToken(string scheme, string ssoToken) + { + SsoTokenable tokenable; + + try + { + tokenable = _dataProtector.Unprotect(ssoToken); + } + catch + { + throw new Exception(_i18nService.T("InvalidSsoToken")); + } + + if (!Guid.TryParse(scheme, out var schemeOrgId) || tokenable.OrganizationId != schemeOrgId) + { + throw new Exception(_i18nService.T("SsoOrganizationIdMismatch")); + } + } + [HttpGet] public async Task ExternalCallback() { @@ -651,7 +679,23 @@ public class AccountController : Controller EmailVerified = emailVerified, ApiKey = CoreHelpers.SecureRandomString(30) }; - await _registerUserCommand.RegisterUser(newUser); + + /* + The feature flag is checked here so that we can send the new MJML welcome email templates. + The other organization invites flows have an OrganizationUser allowing the RegisterUserCommand the ability + to fetch the Organization. The old method RegisterUser(User) here does not have that context, so we need + to use a new method RegisterSSOAutoProvisionedUserAsync(User, Organization) to send the correct email. + [PM-28057]: Prefer RegisterSSOAutoProvisionedUserAsync for SSO auto-provisioned users. + TODO: Remove Feature flag: PM-28221 + */ + if (_featureService.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)) + { + await _registerUserCommand.RegisterSSOAutoProvisionedUserAsync(newUser, organization); + } + else + { + await _registerUserCommand.RegisterUser(newUser); + } // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email var twoFactorPolicy = diff --git a/bitwarden_license/src/Sso/Program.cs b/bitwarden_license/src/Sso/Program.cs index 1a8ce6eb88..bac3bb3d13 100644 --- a/bitwarden_license/src/Sso/Program.cs +++ b/bitwarden_license/src/Sso/Program.cs @@ -1,5 +1,4 @@ using Bit.Core.Utilities; -using Serilog; namespace Bit.Sso; @@ -13,19 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= globalSettings.MinLogLevel.SsoSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index 3ae8883ac4..2f83f3dad0 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -100,8 +100,6 @@ public class Startup IdentityModelEventSource.ShowPII = true; } - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/bitwarden_license/src/Sso/appsettings.Development.json b/bitwarden_license/src/Sso/appsettings.Development.json index 8aae281068..8e24d82528 100644 --- a/bitwarden_license/src/Sso/appsettings.Development.json +++ b/bitwarden_license/src/Sso/appsettings.Development.json @@ -24,6 +24,13 @@ "storage": { "connectionString": "UseDevelopmentStorage=true" }, - "developmentDirectory": "../../../dev" + "developmentDirectory": "../../../dev", + "pricingUri": "https://billingpricing.qa.bitwarden.pw", + "mail": { + "smtp": { + "host": "localhost", + "port": 10250 + } + } } } diff --git a/bitwarden_license/src/Sso/appsettings.json b/bitwarden_license/src/Sso/appsettings.json index 73c85044cc..9a5df42f7f 100644 --- a/bitwarden_license/src/Sso/appsettings.json +++ b/bitwarden_license/src/Sso/appsettings.json @@ -13,7 +13,11 @@ "mail": { "sendGridApiKey": "SECRET", "amazonConfigSetName": "Email", - "replyToEmail": "no-reply@bitwarden.com" + "replyToEmail": "no-reply@bitwarden.com", + "smtp": { + "host": "localhost", + "port": 10250 + } }, "identityServer": { "certificateThumbprint": "SECRET" diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 2bb02c3cee..810429d658 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -13,7 +13,7 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -131,7 +131,7 @@ public class RemoveOrganizationFromProviderCommandTests Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -156,7 +156,7 @@ public class RemoveOrganizationFromProviderCommandTests "b@example.com" ]); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId, Arg.Is( + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer"))) .Returns(GetSubscription(organization.GatewaySubscriptionId, organization.GatewayCustomerId)); @@ -164,12 +164,14 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - await stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, + await stripeAdapter.Received(1).UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Email == "a@example.com")); - await stripeAdapter.Received(1).CustomerDeleteDiscountAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); + + await stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30)); @@ -207,7 +209,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -226,7 +228,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -239,14 +241,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -296,7 +298,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -315,7 +317,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -328,14 +330,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -416,7 +418,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; organization.Enabled = false; // Start with a disabled organization - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -434,7 +436,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Any()) + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(new Customer { Id = "customer_id", @@ -444,7 +446,7 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "new_subscription_id" }); diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index e61cf5f97e..7ec11894ad 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -1,17 +1,23 @@ using Bit.Commercial.Core.AdminConsole.Services; using Bit.Commercial.Core.Test.AdminConsole.AutoFixture; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -20,6 +26,7 @@ using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationFixtures; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tokens; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; @@ -99,6 +106,57 @@ public class ProviderServiceTests .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); } + [Theory, BitAutoData] + public async Task CompleteSetupAsync_WithAutoConfirmEnabled_ThrowsUserCannotJoinProviderError(User user, Provider provider, + string key, + TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + var providerBillingService = sutProvider.GetDependency(); + + var customer = new Customer { Id = "customer_id" }; + providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer); + + var subscription = new Subscription { Id = "subscription_id" }; + providerBillingService.SetupSubscription(provider).Returns(subscription); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List { new() { OrganizationId = Guid.NewGuid(), IsProvider = false } }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + sutProvider.Create(); + + var token = protector.Protect( + $"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, + billingAddress)); + + Assert.Equal(new UserCannotJoinProvider().Message, exception.Message); + } + [Theory, BitAutoData] public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) { @@ -578,6 +636,132 @@ public class ProviderServiceTests Assert.Equal(user.Id, pu.UserId); } + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmEnabledAndPolicyExists_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List + { + new() { OrganizationId = Guid.NewGuid(), IsProvider = false } + }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); + + Assert.Equal(new UserCannotJoinProvider().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmEnabledButNoPolicyExists_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement([]); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + // Act + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + + // Assert + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmDisabled_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(false); + + // Act + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + + // Assert + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + + // Verify that policy check was never called when feature flag is disabled + await sutProvider.GetDependency() + .DidNotReceive() + .GetAsync(user.Id); + } + [Theory, BitAutoData] public async Task ConfirmUsersAsync_NoValid( [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, @@ -624,13 +808,131 @@ public class ProviderServiceTests Assert.Equal("Invalid user.", result[2].Item2); } + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabledAndPolicyExists_ReturnsError( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List + { + new() { OrganizationId = Guid.NewGuid(), IsProvider = false } + }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(u1.Id) + .Returns(policyRequirement); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal(new UserCannotJoinProvider().Message, result[0].Item2); + + // Verify user was not confirmed + await providerUserRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabledButNoPolicyExists_Success( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(new List()); + sutProvider.GetDependency() + .GetAsync(u1.Id) + .Returns(policyRequirement); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal("", result[0].Item2); + + // Verify user was confirmed + await providerUserRepository.Received(1).ReplaceAsync(Arg.Is(pu => + pu.Status == ProviderUserStatusType.Confirmed)); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmDisabled_Success( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(false); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal("", result[0].Item2); + + // Verify user was confirmed + await providerUserRepository.Received(1).ReplaceAsync(Arg.Is(pu => + pu.Status == ProviderUserStatusType.Confirmed)); + + // Verify that policy check was never called when feature flag is disabled + await sutProvider.GetDependency() + .DidNotReceive() + .GetAsync(Arg.Any()); + } + [Theory, BitAutoData] public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, SutProvider sutProvider) { - providerUser.Id = default; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(providerUser, default)); + providerUser.Id = Guid.Empty; + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveUserAsync(providerUser, Guid.Empty)); Assert.Equal("Invite the user first.", exception.Message); } @@ -756,7 +1058,7 @@ public class ProviderServiceTests await organizationRepository.Received(1) .ReplaceAsync(Arg.Is(org => org.BillingEmail == provider.BillingEmail)); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync( organization.GatewayCustomerId, Arg.Is(options => options.Email == provider.BillingEmail)); @@ -811,12 +1113,12 @@ public class ProviderServiceTests organization.Plan = "Enterprise (Monthly)"; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var expectedPlanType = PlanType.EnterpriseMonthly2020; sutProvider.GetDependency().GetPlanOrThrow(expectedPlanType) - .Returns(StaticStore.GetPlan(expectedPlanType)); + .Returns(MockPlans.Get(expectedPlanType)); var expectedPlanId = "2020-enterprise-org-seat-monthly"; @@ -827,9 +1129,9 @@ public class ProviderServiceTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var subscriptionItem = GetSubscription(organization.GatewaySubscriptionId); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(GetSubscription(organization.GatewaySubscriptionId)); - await sutProvider.GetDependency().SubscriptionUpdateAsync( + await sutProvider.GetDependency().UpdateSubscriptionAsync( organization.GatewaySubscriptionId, SubscriptionUpdateRequest(expectedPlanId, subscriptionItem)); await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs index a7f896ef7a..96dbacfa92 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -63,7 +62,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -95,7 +94,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -129,7 +128,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(false); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -163,7 +162,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -224,7 +223,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "GB" }] @@ -257,7 +256,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -296,7 +295,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -338,7 +337,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -383,7 +382,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -428,7 +427,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -461,7 +460,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) .Returns(new StripeList { Data = [ @@ -470,7 +469,7 @@ public class GetProviderWarningsQueryTests new Registration { Country = "FR" } ] }); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -505,7 +504,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -543,7 +542,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "US" }] diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs index ec52650097..48b971a032 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs @@ -18,6 +18,7 @@ using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.DataProtection; @@ -72,7 +73,7 @@ public class BusinessUnitConverterTests { organization.PlanType = PlanType.EnterpriseAnnually2020; - var enterpriseAnnually2020 = StaticStore.GetPlan(PlanType.EnterpriseAnnually2020); + var enterpriseAnnually2020 = MockPlans.Get(PlanType.EnterpriseAnnually2020); var subscription = new Subscription { @@ -134,7 +135,7 @@ public class BusinessUnitConverterTests _pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2020) .Returns(enterpriseAnnually2020); - var enterpriseAnnually = StaticStore.GetPlan(PlanType.EnterpriseAnnually); + var enterpriseAnnually = MockPlans.Get(PlanType.EnterpriseAnnually); _pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually) .Returns(enterpriseAnnually); @@ -143,11 +144,11 @@ public class BusinessUnitConverterTests await businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey); - await _stripeAdapter.Received(2).CustomerUpdateAsync(subscription.CustomerId, Arg.Any()); + await _stripeAdapter.Received(2).UpdateCustomerAsync(subscription.CustomerId, Arg.Any()); var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, enterpriseAnnually.Type); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(subscription.Id, Arg.Is( + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(subscription.Id, Arg.Is( arguments => arguments.Items.Count == 2 && arguments.Items[0].Id == "subscription_item_id" && @@ -242,7 +243,7 @@ public class BusinessUnitConverterTests argument.Status == ProviderStatusType.Pending && argument.Type == ProviderType.BusinessUnit)).Returns(provider); - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 18c71364e6..93ce33edc4 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -20,9 +20,8 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -85,7 +84,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -113,7 +112,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -140,7 +139,7 @@ public class ProviderBillingServiceTests .Returns(existingPlan); sutProvider.GetDependency().GetPlanOrThrow(existingPlan.PlanType) - .Returns(StaticStore.GetPlan(existingPlan.PlanType)); + .Returns(MockPlans.Get(existingPlan.PlanType)); sutProvider.GetDependency().GetSubscriptionOrThrow(provider) .Returns(new Subscription @@ -155,7 +154,7 @@ public class ProviderBillingServiceTests Id = "si_ent_annual", Price = new Price { - Id = StaticStore.GetPlan(PlanType.EnterpriseAnnually).PasswordManager + Id = MockPlans.Get(PlanType.EnterpriseAnnually).PasswordManager .StripeProviderPortalSeatPlanId }, Quantity = 10 @@ -168,7 +167,7 @@ public class ProviderBillingServiceTests new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly); sutProvider.GetDependency().GetPlanOrThrow(command.NewPlan) - .Returns(StaticStore.GetPlan(command.NewPlan)); + .Returns(MockPlans.Get(command.NewPlan)); // Act await sutProvider.Sut.ChangePlan(command); @@ -180,14 +179,14 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => si.Id == "si_ent_annual" && si.Deleted == true) == 1)); - var newPlanCfg = StaticStore.GetPlan(command.NewPlan); + var newPlanCfg = MockPlans.Get(command.NewPlan); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => @@ -268,7 +267,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -288,7 +287,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -349,7 +348,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -370,7 +369,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -491,7 +490,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); @@ -514,7 +513,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 50 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -535,7 +534,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync( + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpdateSubscriptionAsync( Arg.Any(), Arg.Any()); @@ -573,7 +572,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -598,7 +597,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 95 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -619,7 +618,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -661,7 +660,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -686,7 +685,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 110 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -707,7 +706,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -749,7 +748,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -774,7 +773,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 110 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -795,7 +794,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -827,13 +826,13 @@ public class ProviderBillingServiceTests } ]); - sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(StaticStore.GetPlan(planType)); + sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(MockPlans.Get(planType)); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ new ProviderOrganizationOrganizationDetails { - Plan = StaticStore.GetPlan(planType).Name, + Plan = MockPlans.Get(planType).Name, Status = OrganizationStatusType.Managed, Seats = 5 } @@ -865,13 +864,13 @@ public class ProviderBillingServiceTests } ]); - sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(StaticStore.GetPlan(planType)); + sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(MockPlans.Get(planType)); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ new ProviderOrganizationOrganizationDetails { - Plan = StaticStore.GetPlan(planType).Name, + Plan = MockPlans.Get(planType).Name, Status = OrganizationStatusType.Managed, Seats = 15 } @@ -914,12 +913,12 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -942,7 +941,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); - await stripeAdapter.Received(1).SetupIntentCancel("setup_intent_id", Arg.Is(options => + await stripeAdapter.Received(1).CancelSetupIntentAsync("setup_intent_id", Arg.Is(options => options.CancellationReason == "abandoned")); await sutProvider.GetDependency().Received(1).RemoveSetupIntentForSubscriber(provider.Id); @@ -964,7 +963,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1007,12 +1006,12 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1058,7 +1057,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1100,7 +1099,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1142,7 +1141,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1178,7 +1177,7 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Any()) + stripeAdapter.CreateCustomerAsync(Arg.Any()) .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } }); var actual = await Assert.ThrowsAsync(async () => @@ -1216,7 +1215,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1238,13 +1237,13 @@ public class ProviderBillingServiceTests .Returns(providerPlans); sutProvider.GetDependency().GetPlanOrThrow(PlanType.EnterpriseMonthly) - .Returns(StaticStore.GetPlan(PlanType.EnterpriseMonthly)); + .Returns(MockPlans.Get(PlanType.EnterpriseMonthly)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1266,13 +1265,13 @@ public class ProviderBillingServiceTests .Returns(providerPlans); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly) - .Returns(StaticStore.GetPlan(PlanType.TeamsMonthly)); + .Returns(MockPlans.Get(PlanType.TeamsMonthly)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1317,13 +1316,13 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()) + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Any()) .Returns( new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); @@ -1373,7 +1372,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1381,7 +1380,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && @@ -1449,7 +1448,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1458,7 +1457,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1525,7 +1524,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1538,7 +1537,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntentId); - sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is(options => + sutProvider.GetDependency().GetSetupIntentAsync(setupIntentId, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(new SetupIntent { Id = setupIntentId, @@ -1553,7 +1552,7 @@ public class ProviderBillingServiceTests } }); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1626,7 +1625,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1635,7 +1634,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1704,7 +1703,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1713,7 +1712,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1772,8 +1771,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1806,7 +1805,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1828,7 +1827,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 20 && providerPlan.PurchasedSeats == 5)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1852,8 +1851,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1886,7 +1885,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1908,7 +1907,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 50)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1932,8 +1931,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1966,7 +1965,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1989,7 +1988,7 @@ public class ProviderBillingServiceTests providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 60 && providerPlan.PurchasedSeats == 10)); await stripeAdapter.DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -2006,8 +2005,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -2040,7 +2039,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -2062,7 +2061,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 80 && providerPlan.PurchasedSeats == 0)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -2086,8 +2085,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -2120,7 +2119,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -2142,7 +2141,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.DidNotReceive().ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 1 && @@ -2151,4 +2150,151 @@ public class ProviderBillingServiceTests } #endregion + + #region UpdateProviderNameAndEmail + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullGatewayCustomerId_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.GatewayCustomerId = null; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_EmptyGatewayCustomerId_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.GatewayCustomerId = ""; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullProviderName_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = null; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_EmptyProviderName_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = ""; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_ValidProvider_CallsStripeWithCorrectParameters( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = "Test Provider"; + provider.BillingEmail = "billing@test.com"; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.Email == provider.BillingEmail && + options.Description == provider.Name && + options.InvoiceSettings.CustomFields.Count == 1 && + options.InvoiceSettings.CustomFields[0].Name == "Provider" && + options.InvoiceSettings.CustomFields[0].Value == provider.Name)); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_LongProviderName_UsesFullName( + Provider provider, + SutProvider sutProvider) + { + // Arrange + var longName = new string('A', 50); // 50 characters + provider.Name = longName; + provider.BillingEmail = "billing@test.com"; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.InvoiceSettings.CustomFields[0].Value == longName)); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullBillingEmail_UpdatesWithNull( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = "Test Provider"; + provider.BillingEmail = null; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.Email == null && + options.Description == provider.Name)); + } + + #endregion } diff --git a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs index 16ae8f7f2c..776403fdd5 100644 --- a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs @@ -6,7 +6,7 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Settings; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -69,7 +69,7 @@ public class MaxProjectsQueryTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetPlan(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1); @@ -114,7 +114,7 @@ public class MaxProjectsQueryTests .Returns(projects); sutProvider.GetDependency().GetPlan(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd); diff --git a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs new file mode 100644 index 0000000000..659a6d1233 --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs @@ -0,0 +1,130 @@ +using Bit.Core.SecretsManager.Entities; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Commercial.Core.Test.SecretsManager.Repositories; + +public class SecretVersionRepositoryTests +{ + [Theory] + [BitAutoData] + public void SecretVersion_EntityCreation_Success(SecretVersion secretVersion) + { + // Arrange & Act + secretVersion.SetNewId(); + + // Assert + Assert.NotEqual(Guid.Empty, secretVersion.Id); + Assert.NotEqual(Guid.Empty, secretVersion.SecretId); + Assert.NotNull(secretVersion.Value); + Assert.NotEqual(default, secretVersion.VersionDate); + } + + [Theory] + [BitAutoData] + public void SecretVersion_WithServiceAccountEditor_Success(SecretVersion secretVersion, Guid serviceAccountId) + { + // Arrange & Act + secretVersion.EditorServiceAccountId = serviceAccountId; + secretVersion.EditorOrganizationUserId = null; + + // Assert + Assert.Equal(serviceAccountId, secretVersion.EditorServiceAccountId); + Assert.Null(secretVersion.EditorOrganizationUserId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_WithOrganizationUserEditor_Success(SecretVersion secretVersion, Guid organizationUserId) + { + // Arrange & Act + secretVersion.EditorOrganizationUserId = organizationUserId; + secretVersion.EditorServiceAccountId = null; + + // Assert + Assert.Equal(organizationUserId, secretVersion.EditorOrganizationUserId); + Assert.Null(secretVersion.EditorServiceAccountId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_NullableEditors_Success(SecretVersion secretVersion) + { + // Arrange & Act + secretVersion.EditorServiceAccountId = null; + secretVersion.EditorOrganizationUserId = null; + + // Assert + Assert.Null(secretVersion.EditorServiceAccountId); + Assert.Null(secretVersion.EditorOrganizationUserId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_VersionDateSet_Success(SecretVersion secretVersion) + { + // Arrange + var versionDate = DateTime.UtcNow; + + // Act + secretVersion.VersionDate = versionDate; + + // Assert + Assert.Equal(versionDate, secretVersion.VersionDate); + } + + [Theory] + [BitAutoData] + public void SecretVersion_ValueEncrypted_Success(SecretVersion secretVersion, string encryptedValue) + { + // Arrange & Act + secretVersion.Value = encryptedValue; + + // Assert + Assert.Equal(encryptedValue, secretVersion.Value); + Assert.NotEmpty(secretVersion.Value); + } + + [Theory] + [BitAutoData] + public void SecretVersion_MultipleVersions_DifferentIds(List secretVersions, Guid secretId) + { + // Arrange & Act + foreach (var version in secretVersions) + { + version.SecretId = secretId; + version.SetNewId(); + } + + // Assert + var distinctIds = secretVersions.Select(v => v.Id).Distinct(); + Assert.Equal(secretVersions.Count, distinctIds.Count()); + Assert.All(secretVersions, v => Assert.Equal(secretId, v.SecretId)); + } + + [Theory] + [BitAutoData] + public void SecretVersion_VersionDateOrdering_Success(SecretVersion version1, SecretVersion version2, SecretVersion version3, Guid secretId) + { + // Arrange + var now = DateTime.UtcNow; + version1.SecretId = secretId; + version1.VersionDate = now.AddDays(-2); + + version2.SecretId = secretId; + version2.VersionDate = now.AddDays(-1); + + version3.SecretId = secretId; + version3.VersionDate = now; + + var versions = new List { version2, version3, version1 }; + + // Act + var orderedVersions = versions.OrderByDescending(v => v.VersionDate).ToList(); + + // Assert + Assert.Equal(version3.Id, orderedVersions[0].Id); // Most recent + Assert.Equal(version2.Id, orderedVersions[1].Id); + Assert.Equal(version1.Id, orderedVersions[2].Id); // Oldest + } +} diff --git a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs index 0fe37d89fd..b276174814 100644 --- a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs +++ b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs @@ -3,12 +3,15 @@ using System.Security.Claims; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; +using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Tokens; using Bit.Sso.Controllers; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -18,6 +21,7 @@ using Duende.IdentityServer.Models; using Duende.IdentityServer.Services; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.DependencyInjection; using NSubstitute; @@ -1008,4 +1012,256 @@ public class AccountControllerTest _output.WriteLine($"Scenario={scenario} | OFF: SSO={offCounts.UserGetBySso}, Email={offCounts.UserGetByEmail}, Org={offCounts.OrgGetById}, OrgUserByOrg={offCounts.OrgUserGetByOrg}, OrgUserByEmail={offCounts.OrgUserGetByEmail}"); } } + + [Theory, BitAutoData] + public async Task AutoProvisionUserAsync_WithFeatureFlagEnabled_CallsRegisterSSOAutoProvisionedUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-new-user"; + var email = "newuser@example.com"; + var organization = new Organization { Id = orgId, Name = "Test Org", Seats = null }; + + // No existing user (JIT provisioning scenario) + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + + // Feature flag enabled + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + // Mock the RegisterSSOAutoProvisionedUserAsync to return success + sutProvider.GetDependency() + .RegisterSSOAutoProvisionedUserAsync(Arg.Any(), Arg.Any()) + .Returns(IdentityResult.Success); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "New User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method!.Invoke( + sutProvider.Sut, + new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var result = await task; + + // Assert + await sutProvider.GetDependency().Received(1) + .RegisterSSOAutoProvisionedUserAsync( + Arg.Is(u => u.Email == email && u.Name == "New User"), + Arg.Is(o => o.Id == orgId && o.Name == "Test Org")); + + Assert.NotNull(result.user); + Assert.Equal(email, result.user.Email); + Assert.Equal(organization.Id, result.organization.Id); + } + + [Theory, BitAutoData] + public async Task AutoProvisionUserAsync_WithFeatureFlagDisabled_CallsRegisterUserInstead( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-legacy-user"; + var email = "legacyuser@example.com"; + var organization = new Organization { Id = orgId, Name = "Test Org", Seats = null }; + + // No existing user (JIT provisioning scenario) + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + + // Feature flag disabled + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(false); + + // Mock the RegisterUser to return success + sutProvider.GetDependency() + .RegisterUser(Arg.Any()) + .Returns(IdentityResult.Success); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Legacy User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method!.Invoke( + sutProvider.Sut, + new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var result = await task; + + // Assert + await sutProvider.GetDependency().Received(1) + .RegisterUser(Arg.Is(u => u.Email == email && u.Name == "Legacy User")); + + // Verify the new method was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RegisterSSOAutoProvisionedUserAsync(Arg.Any(), Arg.Any()); + + Assert.NotNull(result.user); + Assert.Equal(email, result.user.Email); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMatchingOrgId_Succeeds( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var orgId = organization.Id; + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with matching org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock URL helper for IsLocalUrl check + var urlHelper = Substitute.For(); + urlHelper.IsLocalUrl(returnUrl).Returns(true); + sutProvider.Sut.Url = urlHelper; + + // Mock interaction service for IsValidReturnUrl check + var interactionService = sutProvider.GetDependency(); + interactionService.IsValidReturnUrl(returnUrl).Returns(true); + + // Act + var result = sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken); + + // Assert + var challengeResult = Assert.IsType(result); + Assert.Contains(scheme, challengeResult.AuthenticationSchemes); + Assert.NotNull(challengeResult.Properties); + Assert.Equal(scheme, challengeResult.Properties.Items["scheme"]); + Assert.Equal(returnUrl, challengeResult.Properties.Items["return_url"]); + Assert.Equal(state, challengeResult.Properties.Items["state"]); + Assert.Equal(userIdentifier, challengeResult.Properties.Items["user_identifier"]); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMismatchedOrgId_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var correctOrgId = organization.Id; + var wrongOrgId = Guid.NewGuid(); + var scheme = wrongOrgId.ToString(); // Different from tokenable's org ID + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with different org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); // Contains correctOrgId + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSchemeFormat_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var scheme = "not-a-valid-guid"; + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a valid tokenable + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSsoToken_ThrowsInvalidSsoToken( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "invalid-corrupted-token"; + + // Mock the data protector to throw when trying to unprotect + var dataProtector = sutProvider.GetDependency>(); + dataProtector.Unprotect(ssoToken).Returns(_ => throw new Exception("Token validation failed")); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("InvalidSsoToken", ex.Message); + } } diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs index 5f562a30c5..9ad231a63d 100644 --- a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs @@ -200,6 +200,38 @@ public class GroupsControllerTests : IClassFixture, IAsy AssertHelper.AssertPropertyEqual(expectedResponse, responseModel); } + [Fact] + public async Task GetList_SearchDisplayNameWithoutOptionalParameters_Success() + { + string filter = "displayName eq Test Group 2"; + int? itemsPerPage = null; + int? startIndex = null; + var expectedResponse = new ScimListResponseModel + { + ItemsPerPage = 50, //default value + TotalResults = 1, + StartIndex = 1, //default value + Resources = new List + { + new ScimGroupResponseModel + { + Id = ScimApplicationFactory.TestGroupId2, + DisplayName = "Test Group 2", + ExternalId = "B", + Schemas = new List { ScimConstants.Scim2SchemaGroup } + } + }, + Schemas = new List { ScimConstants.Scim2SchemaListResponse } + }; + + var context = await _factory.GroupsGetListAsync(ScimApplicationFactory.TestOrganizationId1, filter, itemsPerPage, startIndex); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + var responseModel = JsonSerializer.Deserialize>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + AssertHelper.AssertPropertyEqual(expectedResponse, responseModel); + } + [Fact] public async Task Post_Success() { diff --git a/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json b/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json new file mode 100644 index 0000000000..496d0c075f --- /dev/null +++ b/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json @@ -0,0 +1,36 @@ +{ + "globalSettings": { + "baseServiceUri": { + "vault": "https://localhost:8080", + "api": "http://localhost:4000", + "identity": "http://localhost:33656", + "admin": "http://localhost:62911", + "notifications": "http://localhost:61840", + "sso": "http://localhost:51822", + "internalNotifications": "http://localhost:61840", + "internalAdmin": "http://localhost:62911", + "internalIdentity": "http://localhost:33656", + "internalApi": "http://localhost:4000", + "internalVault": "https://localhost:8080", + "internalSso": "http://localhost:51822", + "internalScim": "http://localhost:44559" + }, + "mail": { + "smtp": { + "host": "localhost", + "port": 10250 + } + }, + "attachment": { + "connectionString": "UseDevelopmentStorage=true", + "baseUrl": "http://localhost:4000/attachments/" + }, + "events": { + "connectionString": "UseDevelopmentStorage=true" + }, + "storage": { + "connectionString": "UseDevelopmentStorage=true" + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" + } +} diff --git a/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs b/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs index 1599b6e390..b835e1fe6b 100644 --- a/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs +++ b/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Scim.Groups; +using Bit.Scim.Models; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -24,7 +25,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, null, count, startIndex); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Count = count, StartIndex = startIndex }); AssertHelper.AssertPropertyEqual(groups.Skip(startIndex - 1).Take(count).ToList(), result.groupList); AssertHelper.AssertPropertyEqual(groups.Count, result.totalResults); @@ -47,7 +48,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -67,7 +68,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -90,7 +91,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -112,7 +113,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); diff --git a/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs b/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs index 9352e5c202..7424b50c0d 100644 --- a/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs @@ -1,5 +1,6 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Scim.Models; using Bit.Scim.Users; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; diff --git a/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs b/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs index f391c93fe3..8b6c850c6f 100644 --- a/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs @@ -1,6 +1,6 @@ using System.Text.Json; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs index ac23e7ecc1..eb8804cac5 100644 --- a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -36,7 +37,7 @@ public class PostUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); sutProvider.GetDependency() .InviteUserAsync(organizationId, diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index c5e42cf9e3..3554306ddb 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -57,7 +57,6 @@ services: mysql: image: mysql:8.0 - container_name: bw-mysql ports: - "3306:3306" command: @@ -88,7 +87,6 @@ services: idp: image: kenchan0130/simplesamlphp:1.19.8 - container_name: idp ports: - "8090:8080" environment: @@ -102,7 +100,6 @@ services: rabbitmq: image: rabbitmq:4.1.3-management - container_name: rabbitmq ports: - "5672:5672" - "15672:15672" @@ -116,7 +113,6 @@ services: reverse-proxy: image: nginx:alpine - container_name: reverse-proxy volumes: - "./reverse-proxy.conf:/etc/nginx/conf.d/default.conf" ports: @@ -126,7 +122,6 @@ services: - proxy service-bus: - container_name: service-bus image: mcr.microsoft.com/azure-messaging/servicebus-emulator:latest pull_policy: always volumes: @@ -142,7 +137,6 @@ services: redis: image: redis:alpine - container_name: bw-redis ports: - "6379:6379" volumes: diff --git a/dev/generate_openapi_files.ps1 b/dev/generate_openapi_files.ps1 index 9eca7dc734..011319b3a3 100644 --- a/dev/generate_openapi_files.ps1 +++ b/dev/generate_openapi_files.ps1 @@ -18,11 +18,11 @@ if ($LASTEXITCODE -ne 0) { # Api internal & public Set-Location "../../src/Api" dotnet build -dotnet swagger tofile --output "../../api.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "internal" +dotnet swagger tofile --output "../../api.json" "./bin/Debug/net8.0/Api.dll" "internal" if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } -dotnet swagger tofile --output "../../api.public.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "public" +dotnet swagger tofile --output "../../api.public.json" "./bin/Debug/net8.0/Api.dll" "public" if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } diff --git a/dev/secrets.json.example b/dev/secrets.json.example index c6a16846e9..0d4213aec1 100644 --- a/dev/secrets.json.example +++ b/dev/secrets.json.example @@ -33,6 +33,10 @@ "id": "", "key": "" }, + "events": { + "connectionString": "", + "queueName": "event" + }, "licenseDirectory": "", "enableNewDeviceVerification": true, "enableEmailVerification": true diff --git a/dev/verify_migrations.ps1 b/dev/verify_migrations.ps1 new file mode 100644 index 0000000000..d63c34f2bd --- /dev/null +++ b/dev/verify_migrations.ps1 @@ -0,0 +1,132 @@ +#!/usr/bin/env pwsh + +<# +.SYNOPSIS + Validates that new database migration files follow naming conventions and chronological order. + +.DESCRIPTION + This script validates migration files in util/Migrator/DbScripts/ to ensure: + 1. New migrations follow the naming format: YYYY-MM-DD_NN_Description.sql + 2. New migrations are chronologically ordered (filename sorts after existing migrations) + 3. Dates use leading zeros (e.g., 2025-01-05, not 2025-1-5) + 4. A 2-digit sequence number is included (e.g., _00, _01) + +.PARAMETER BaseRef + The base git reference to compare against (e.g., 'main', 'HEAD~1') + +.PARAMETER CurrentRef + The current git reference (defaults to 'HEAD') + +.EXAMPLE + # For pull requests - compare against main branch + .\verify_migrations.ps1 -BaseRef main + +.EXAMPLE + # For pushes - compare against previous commit + .\verify_migrations.ps1 -BaseRef HEAD~1 +#> + +param( + [Parameter(Mandatory = $true)] + [string]$BaseRef, + + [Parameter(Mandatory = $false)] + [string]$CurrentRef = "HEAD" +) + +# Use invariant culture for consistent string comparison +[System.Threading.Thread]::CurrentThread.CurrentCulture = [System.Globalization.CultureInfo]::InvariantCulture + +$migrationPath = "util/Migrator/DbScripts" + +# Get list of migrations from base reference +try { + $baseMigrations = git ls-tree -r --name-only $BaseRef -- "$migrationPath/*.sql" 2>$null | Sort-Object + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Could not retrieve migrations from base reference '$BaseRef'" + $baseMigrations = @() + } +} +catch { + Write-Host "Warning: Could not retrieve migrations from base reference '$BaseRef'" + $baseMigrations = @() +} + +# Get list of migrations from current reference +$currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$migrationPath/*.sql" | Sort-Object + +# Find added migrations +$addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } + +if ($addedMigrations.Count -eq 0) { + Write-Host "No new migration files added." + exit 0 +} + +Write-Host "New migration files detected:" +$addedMigrations | ForEach-Object { Write-Host " $_" } +Write-Host "" + +# Get the last migration from base reference +if ($baseMigrations.Count -eq 0) { + Write-Host "No previous migrations found (initial commit?). Skipping validation." + exit 0 +} + +$lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) +Write-Host "Last migration in base reference: $lastBaseMigration" +Write-Host "" + +# Required format regex: YYYY-MM-DD_NN_Description.sql +$formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' + +$validationFailed = $false + +foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Validate NEW migration filename format + if ($migrationName -notmatch $formatRegex) { + Write-Host "ERROR: Migration '$migrationName' does not match required format" + Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" + Write-Host " - YYYY: 4-digit year" + Write-Host " - MM: 2-digit month with leading zero (01-12)" + Write-Host " - DD: 2-digit day with leading zero (01-31)" + Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" + Write-Host "Example: 2025-01-15_00_MyMigration.sql" + $validationFailed = $true + continue + } + + # Compare migration name with last base migration (using ordinal string comparison) + if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" + $validationFailed = $true + } + else { + Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + } +} + +Write-Host "" + +if ($validationFailed) { + Write-Host "FAILED: One or more migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new migration files must:" + Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" + Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" + Write-Host " 4. Have a filename that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" + Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 3. Ensure the date is after $lastBaseMigration" + Write-Host "" + Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + exit 1 +} + +Write-Host "SUCCESS: All new migrations are correctly named and in chronological order" +exit 0 diff --git a/global.json b/global.json index d25197db39..4cbe3f083a 100644 --- a/global.json +++ b/global.json @@ -5,6 +5,7 @@ }, "msbuild-sdks": { "Microsoft.Build.Traversal": "4.1.0", - "Microsoft.Build.Sql": "1.0.0" + "Microsoft.Build.Sql": "1.0.0", + "Bitwarden.Server.Sdk": "1.2.0" } } diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 0d992cb96a..cd370e3898 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -14,8 +14,10 @@ using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -41,7 +43,7 @@ public class OrganizationsController : Controller private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; private readonly IPolicyRepository _policyRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IApplicationCacheService _applicationCacheService; private readonly GlobalSettings _globalSettings; private readonly IProviderRepository _providerRepository; @@ -56,6 +58,7 @@ public class OrganizationsController : Controller private readonly IOrganizationInitiateDeleteCommand _organizationInitiateDeleteCommand; private readonly IPricingClient _pricingClient; private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand; + private readonly IOrganizationBillingService _organizationBillingService; public OrganizationsController( IOrganizationRepository organizationRepository, @@ -66,7 +69,7 @@ public class OrganizationsController : Controller ICollectionRepository collectionRepository, IGroupRepository groupRepository, IPolicyRepository policyRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IApplicationCacheService applicationCacheService, GlobalSettings globalSettings, IProviderRepository providerRepository, @@ -80,7 +83,8 @@ public class OrganizationsController : Controller IProviderBillingService providerBillingService, IOrganizationInitiateDeleteCommand organizationInitiateDeleteCommand, IPricingClient pricingClient, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IOrganizationBillingService organizationBillingService) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -105,6 +109,7 @@ public class OrganizationsController : Controller _organizationInitiateDeleteCommand = organizationInitiateDeleteCommand; _pricingClient = pricingClient; _resendOrganizationInviteCommand = resendOrganizationInviteCommand; + _organizationBillingService = organizationBillingService; } [RequirePermission(Permission.Org_List_View)] @@ -241,6 +246,8 @@ public class OrganizationsController : Controller var existingOrganizationData = new Organization { Id = organization.Id, + Name = organization.Name, + BillingEmail = organization.BillingEmail, Status = organization.Status, PlanType = organization.PlanType, Seats = organization.Seats @@ -286,6 +293,22 @@ public class OrganizationsController : Controller await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + // Sync name/email changes to Stripe + if (existingOrganizationData.Name != organization.Name || existingOrganizationData.BillingEmail != organization.BillingEmail) + { + try + { + await _organizationBillingService.UpdateOrganizationNameAndEmail(organization); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for organization {OrganizationId}. Database was updated successfully.", + organization.Id); + TempData["Warning"] = "Organization updated successfully, but Stripe customer name/email synchronization failed."; + } + } + return RedirectToAction("Edit", new { id }); } @@ -473,6 +496,7 @@ public class OrganizationsController : Controller organization.UseOrganizationDomains = model.UseOrganizationDomains; organization.UseAdminSponsoredFamilies = model.UseAdminSponsoredFamilies; organization.UseAutomaticUserConfirmation = model.UseAutomaticUserConfirmation; + organization.UsePhishingBlocker = model.UsePhishingBlocker; //secrets organization.SmSeats = model.SmSeats; diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index 9344179a77..d9135e1d1c 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -56,6 +56,7 @@ public class ProvidersController : Controller private readonly IStripeAdapter _stripeAdapter; private readonly IAccessControlService _accessControlService; private readonly ISubscriberService _subscriberService; + private readonly ILogger _logger; public ProvidersController(IOrganizationRepository organizationRepository, IResellerClientOrganizationSignUpCommand resellerClientOrganizationSignUpCommand, @@ -72,7 +73,8 @@ public class ProvidersController : Controller IPricingClient pricingClient, IStripeAdapter stripeAdapter, IAccessControlService accessControlService, - ISubscriberService subscriberService) + ISubscriberService subscriberService, + ILogger logger) { _organizationRepository = organizationRepository; _resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand; @@ -92,6 +94,7 @@ public class ProvidersController : Controller _braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl(); _braintreeMerchantId = globalSettings.Braintree.MerchantId; _subscriberService = subscriberService; + _logger = logger; } [RequirePermission(Permission.Provider_List_View)] @@ -296,6 +299,9 @@ public class ProvidersController : Controller var originalProviderStatus = provider.Enabled; + // Capture original billing email before modifications for Stripe sync + var originalBillingEmail = provider.BillingEmail; + model.ToProvider(provider); // validate the stripe ids to prevent saving a bad one @@ -321,6 +327,22 @@ public class ProvidersController : Controller await _providerService.UpdateAsync(provider); await _applicationCacheService.UpsertProviderAbilityAsync(provider); + // Sync billing email changes to Stripe + if (!string.IsNullOrEmpty(provider.GatewayCustomerId) && originalBillingEmail != provider.BillingEmail) + { + try + { + await _providerBillingService.UpdateProviderNameAndEmail(provider); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.", + provider.Id); + TempData["Warning"] = "Provider updated successfully, but Stripe customer email synchronization failed."; + } + } + if (!provider.IsBillable()) { return RedirectToAction("Edit", new { id }); @@ -339,11 +361,11 @@ public class ProvidersController : Controller ]); await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand); - var customer = await _stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId); if (model.PayByInvoice != customer.ApprovedToPayByInvoice()) { var approvedToPayByInvoice = model.PayByInvoice ? "1" : "0"; - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = new Dictionary { diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs index 6059a003b6..4fff85e1e8 100644 --- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs @@ -107,6 +107,7 @@ public class OrganizationEditModel : OrganizationViewModel MaxAutoscaleSmServiceAccounts = org.MaxAutoscaleSmServiceAccounts; UseOrganizationDomains = org.UseOrganizationDomains; UseAutomaticUserConfirmation = org.UseAutomaticUserConfirmation; + UsePhishingBlocker = org.UsePhishingBlocker; _plans = plans; } @@ -160,6 +161,8 @@ public class OrganizationEditModel : OrganizationViewModel public new bool UseSecretsManager { get; set; } [Display(Name = "Risk Insights")] public new bool UseRiskInsights { get; set; } + [Display(Name = "Phishing Blocker")] + public new bool UsePhishingBlocker { get; set; } [Display(Name = "Admin Sponsored Families")] public bool UseAdminSponsoredFamilies { get; set; } [Display(Name = "Self Host")] @@ -327,6 +330,7 @@ public class OrganizationEditModel : OrganizationViewModel existingOrganization.SmServiceAccounts = SmServiceAccounts; existingOrganization.MaxAutoscaleSmServiceAccounts = MaxAutoscaleSmServiceAccounts; existingOrganization.UseOrganizationDomains = UseOrganizationDomains; + existingOrganization.UsePhishingBlocker = UsePhishingBlocker; return existingOrganization; } } diff --git a/src/Admin/AdminConsole/Models/OrganizationViewModel.cs b/src/Admin/AdminConsole/Models/OrganizationViewModel.cs index 2c126ecd8e..457686be53 100644 --- a/src/Admin/AdminConsole/Models/OrganizationViewModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationViewModel.cs @@ -75,6 +75,7 @@ public class OrganizationViewModel public int OccupiedSmSeatsCount { get; set; } public bool UseSecretsManager => Organization.UseSecretsManager; public bool UseRiskInsights => Organization.UseRiskInsights; + public bool UsePhishingBlocker => Organization.UsePhishingBlocker; public IEnumerable OwnersDetails { get; set; } public IEnumerable AdminsDetails { get; set; } } diff --git a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml index cb71c0fc78..b22859ed60 100644 --- a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml +++ b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml @@ -156,6 +156,10 @@ +
+ + +
@if(FeatureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) {
diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index 46dafd65e7..2dd6de89a0 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -8,6 +8,7 @@ using Bit.Admin.Utilities; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index b85a91719c..f42b22b098 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -5,6 +5,7 @@ using Bit.Admin.Models; using Bit.Admin.Services; using Bit.Admin.Utilities; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -20,7 +21,7 @@ public class UsersController : Controller { private readonly IUserRepository _userRepository; private readonly ICipherRepository _cipherRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly GlobalSettings _globalSettings; private readonly IAccessControlService _accessControlService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; @@ -30,7 +31,7 @@ public class UsersController : Controller public UsersController( IUserRepository userRepository, ICipherRepository cipherRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, GlobalSettings globalSettings, IAccessControlService accessControlService, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, diff --git a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs index 434c265f26..219e6846bd 100644 --- a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs +++ b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs @@ -19,7 +19,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable public virtual async Task StartAsync(CancellationToken cancellationToken) { // Wait 20 seconds to allow database to come online - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); var maxMigrationAttempts = 10; for (var i = 1; i <= maxMigrationAttempts; i++) @@ -41,7 +41,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable { _logger.LogError(e, "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); } } } diff --git a/src/Admin/Program.cs b/src/Admin/Program.cs index 05bf35d41d..006a8223b2 100644 --- a/src/Admin/Program.cs +++ b/src/Admin/Program.cs @@ -16,19 +16,8 @@ public class Program o.Limits.MaxRequestLineSize = 20_000; }); webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= globalSettings.MinLogLevel.AdminSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index 5ecbdc899c..87d68a7ac6 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -132,11 +132,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Admin/appsettings.Development.json b/src/Admin/appsettings.Development.json index 861f9be98d..15d61f493f 100644 --- a/src/Admin/appsettings.Development.json +++ b/src/Admin/appsettings.Development.json @@ -27,6 +27,7 @@ }, "storage": { "connectionString": "UseDevelopmentStorage=true" - } + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs b/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs new file mode 100644 index 0000000000..9b147c3c54 --- /dev/null +++ b/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs @@ -0,0 +1,26 @@ +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.Models.Api; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.AdminConsole.Controllers; + +public abstract class BaseAdminConsoleController : Controller +{ + protected static IResult Handle(CommandResult commandResult) => + commandResult.Match( + error => error switch + { + BadRequestError badRequest => TypedResults.BadRequest(new ErrorResponseModel(badRequest.Message)), + NotFoundError notFound => TypedResults.NotFound(new ErrorResponseModel(notFound.Message)), + InternalError internalError => TypedResults.Json( + new ErrorResponseModel(internalError.Message), + statusCode: StatusCodes.Status500InternalServerError), + _ => TypedResults.Json( + new ErrorResponseModel(error.Message), + statusCode: StatusCodes.Status500InternalServerError + ) + }, + _ => TypedResults.NoContent() + ); +} diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs index 0b7fe8dffe..f172a23529 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs @@ -1,8 +1,8 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; using Bit.Core.Context; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -12,8 +12,10 @@ namespace Bit.Api.AdminConsole.Controllers; [Authorize("Application")] public class OrganizationIntegrationConfigurationController( ICurrentContext currentContext, - IOrganizationIntegrationRepository integrationRepository, - IOrganizationIntegrationConfigurationRepository integrationConfigurationRepository) : Controller + ICreateOrganizationIntegrationConfigurationCommand createCommand, + IUpdateOrganizationIntegrationConfigurationCommand updateCommand, + IDeleteOrganizationIntegrationConfigurationCommand deleteCommand, + IGetOrganizationIntegrationConfigurationsQuery getQuery) : Controller { [HttpGet("")] public async Task> GetAsync( @@ -24,13 +26,8 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - var configurations = await integrationConfigurationRepository.GetManyByIntegrationAsync(integrationId); + var configurations = await getQuery.GetManyByIntegrationAsync(organizationId, integrationId); return configurations .Select(configuration => new OrganizationIntegrationConfigurationResponseModel(configuration)) .ToList(); @@ -46,19 +43,11 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - var organizationIntegrationConfiguration = model.ToOrganizationIntegrationConfiguration(integrationId); - var configuration = await integrationConfigurationRepository.CreateAsync(organizationIntegrationConfiguration); - return new OrganizationIntegrationConfigurationResponseModel(configuration); + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var created = await createCommand.CreateAsync(organizationId, integrationId, configuration); + + return new OrganizationIntegrationConfigurationResponseModel(created); } [HttpPut("{configurationId:guid}")] @@ -72,26 +61,11 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var updated = await updateCommand.UpdateAsync(organizationId, integrationId, configurationId, configuration); - var newConfiguration = model.ToOrganizationIntegrationConfiguration(configuration); - await integrationConfigurationRepository.ReplaceAsync(newConfiguration); - - return new OrganizationIntegrationConfigurationResponseModel(newConfiguration); + return new OrganizationIntegrationConfigurationResponseModel(updated); } [HttpDelete("{configurationId:guid}")] @@ -101,19 +75,8 @@ public class OrganizationIntegrationConfigurationController( { throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } - - await integrationConfigurationRepository.DeleteAsync(configuration); + await deleteCommand.DeleteAsync(organizationId, integrationId, configurationId); } [HttpPost("{configurationId:guid}/delete")] diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs index 181811e892..b82fe3dfa8 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs @@ -1,8 +1,8 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; using Bit.Core.Context; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -12,7 +12,10 @@ namespace Bit.Api.AdminConsole.Controllers; [Authorize("Application")] public class OrganizationIntegrationController( ICurrentContext currentContext, - IOrganizationIntegrationRepository integrationRepository) : Controller + ICreateOrganizationIntegrationCommand createCommand, + IUpdateOrganizationIntegrationCommand updateCommand, + IDeleteOrganizationIntegrationCommand deleteCommand, + IGetOrganizationIntegrationsQuery getQuery) : Controller { [HttpGet("")] public async Task> GetAsync(Guid organizationId) @@ -22,7 +25,7 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integrations = await getQuery.GetManyByOrganizationAsync(organizationId); return integrations .Select(integration => new OrganizationIntegrationResponseModel(integration)) .ToList(); @@ -36,8 +39,10 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.CreateAsync(model.ToOrganizationIntegration(organizationId)); - return new OrganizationIntegrationResponseModel(integration); + var integration = model.ToOrganizationIntegration(organizationId); + var created = await createCommand.CreateAsync(integration); + + return new OrganizationIntegrationResponseModel(created); } [HttpPut("{integrationId:guid}")] @@ -48,14 +53,10 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration is null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } + var integration = model.ToOrganizationIntegration(organizationId); + var updated = await updateCommand.UpdateAsync(organizationId, integrationId, integration); - await integrationRepository.ReplaceAsync(model.ToOrganizationIntegration(integration)); - return new OrganizationIntegrationResponseModel(integration); + return new OrganizationIntegrationResponseModel(updated); } [HttpDelete("{integrationId:guid}")] @@ -66,13 +67,7 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration is null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - - await integrationRepository.DeleteAsync(integration); + await deleteCommand.DeleteAsync(organizationId, integrationId); } [HttpPost("{integrationId:guid}/delete")] diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 4b9f7e5d71..a380d2f0d9 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -11,8 +11,10 @@ using Bit.Api.Models.Response; using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -20,6 +22,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Pricing; @@ -38,12 +41,14 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; +using V1_RevokeOrganizationUserCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1.IRevokeOrganizationUserCommand; +using V2_RevokeOrganizationUserCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; namespace Bit.Api.AdminConsole.Controllers; [Route("organizations/{orgId}/users")] [Authorize("Application")] -public class OrganizationUsersController : Controller +public class OrganizationUsersController : BaseAdminConsoleController { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; @@ -68,10 +73,13 @@ public class OrganizationUsersController : Controller private readonly IFeatureService _featureService; private readonly IPricingClient _pricingClient; private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand; + private readonly IBulkResendOrganizationInvitesCommand _bulkResendOrganizationInvitesCommand; + private readonly IAutomaticallyConfirmOrganizationUserCommand _automaticallyConfirmOrganizationUserCommand; + private readonly V2_RevokeOrganizationUserCommand.IRevokeOrganizationUserCommand _revokeOrganizationUserCommandVNext; private readonly IConfirmOrganizationUserCommand _confirmOrganizationUserCommand; private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; private readonly IInitPendingOrganizationCommand _initPendingOrganizationCommand; - private readonly IRevokeOrganizationUserCommand _revokeOrganizationUserCommand; + private readonly V1_RevokeOrganizationUserCommand _revokeOrganizationUserCommand; private readonly IAdminRecoverAccountCommand _adminRecoverAccountCommand; public OrganizationUsersController(IOrganizationRepository organizationRepository, @@ -99,9 +107,12 @@ public class OrganizationUsersController : Controller IConfirmOrganizationUserCommand confirmOrganizationUserCommand, IRestoreOrganizationUserCommand restoreOrganizationUserCommand, IInitPendingOrganizationCommand initPendingOrganizationCommand, - IRevokeOrganizationUserCommand revokeOrganizationUserCommand, + V1_RevokeOrganizationUserCommand revokeOrganizationUserCommand, IResendOrganizationInviteCommand resendOrganizationInviteCommand, - IAdminRecoverAccountCommand adminRecoverAccountCommand) + IBulkResendOrganizationInvitesCommand bulkResendOrganizationInvitesCommand, + IAdminRecoverAccountCommand adminRecoverAccountCommand, + IAutomaticallyConfirmOrganizationUserCommand automaticallyConfirmOrganizationUserCommand, + V2_RevokeOrganizationUserCommand.IRevokeOrganizationUserCommand revokeOrganizationUserCommandVNext) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -126,6 +137,9 @@ public class OrganizationUsersController : Controller _featureService = featureService; _pricingClient = pricingClient; _resendOrganizationInviteCommand = resendOrganizationInviteCommand; + _bulkResendOrganizationInvitesCommand = bulkResendOrganizationInvitesCommand; + _automaticallyConfirmOrganizationUserCommand = automaticallyConfirmOrganizationUserCommand; + _revokeOrganizationUserCommandVNext = revokeOrganizationUserCommandVNext; _confirmOrganizationUserCommand = confirmOrganizationUserCommand; _restoreOrganizationUserCommand = restoreOrganizationUserCommand; _initPendingOrganizationCommand = initPendingOrganizationCommand; @@ -267,7 +281,17 @@ public class OrganizationUsersController : Controller public async Task> BulkReinvite(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ResendInvitesAsync(orgId, userId.Value, model.Ids); + + IEnumerable> result; + if (_featureService.IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud)) + { + result = await _bulkResendOrganizationInvitesCommand.BulkResendInvitesAsync(orgId, userId.Value, model.Ids); + } + else + { + result = await _organizationService.ResendInvitesAsync(orgId, userId.Value, model.Ids); + } + return new ListResponseModel( result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); } @@ -477,43 +501,10 @@ public class OrganizationUsersController : Controller } } +#nullable enable [HttpPut("{id}/reset-password")] [Authorize] public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) - { - if (_featureService.IsEnabled(FeatureFlagKeys.AccountRecoveryCommand)) - { - // TODO: remove legacy implementation after feature flag is enabled. - return await PutResetPasswordNew(orgId, id, model); - } - - // Get the users role, since provider users aren't a member of the organization we use the owner check - var orgUserType = await _currentContext.OrganizationOwner(orgId) - ? OrganizationUserType.Owner - : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgId)?.Type; - if (orgUserType == null) - { - return TypedResults.NotFound(); - } - - var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgId, id, model.NewMasterPasswordHash, model.Key); - if (result.Succeeded) - { - return TypedResults.Ok(); - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - return TypedResults.BadRequest(ModelState); - } - -#nullable enable - // TODO: make sure the route and authorize attributes are maintained when the legacy implementation is removed. - private async Task PutResetPasswordNew(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) { var targetOrganizationUser = await _organizationUserRepository.GetByIdAsync(id); if (targetOrganizationUser == null || targetOrganizationUser.OrganizationId != orgId) @@ -656,7 +647,29 @@ public class OrganizationUsersController : Controller [Authorize] public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - return await RestoreOrRevokeUsersAsync(orgId, model, _revokeOrganizationUserCommand.RevokeUsersAsync); + if (!_featureService.IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2)) + { + return await RestoreOrRevokeUsersAsync(orgId, model, _revokeOrganizationUserCommand.RevokeUsersAsync); + } + + var currentUserId = _userService.GetProperUserId(User); + if (currentUserId == null) + { + throw new UnauthorizedAccessException(); + } + + var results = await _revokeOrganizationUserCommandVNext.RevokeUsersAsync( + new V2_RevokeOrganizationUserCommand.RevokeOrganizationUsersRequest( + orgId, + model.Ids.ToArray(), + new StandardUser(currentUserId.Value, await _currentContext.OrganizationOwner(orgId)))); + + return new ListResponseModel(results + .Select(result => new OrganizationUserBulkResponseModel(result.Id, + result.Result.Match( + error => error.Message, + _ => string.Empty + )))); } [HttpPatch("revoke")] @@ -738,6 +751,31 @@ public class OrganizationUsersController : Controller await BulkEnableSecretsManagerAsync(orgId, model); } + [HttpPost("{id}/auto-confirm")] + [Authorize] + [RequireFeature(FeatureFlagKeys.AutomaticConfirmUsers)] + public async Task AutomaticallyConfirmOrganizationUserAsync([FromRoute] Guid orgId, + [FromRoute] Guid id, + [FromBody] OrganizationUserConfirmRequestModel model) + { + var userId = _userService.GetProperUserId(User); + + if (userId is null || userId.Value == Guid.Empty) + { + return TypedResults.Unauthorized(); + } + + return Handle(await _automaticallyConfirmOrganizationUserCommand.AutomaticallyConfirmOrganizationUserAsync( + new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationId = orgId, + OrganizationUserId = id, + Key = model.Key, + DefaultUserCollectionName = model.DefaultUserCollectionName, + PerformedBy = new StandardUser(userId.Value, await _currentContext.OrganizationOwner(orgId)), + })); + } + private async Task RestoreOrRevokeUserAsync( Guid orgId, Guid id, diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 590895665d..100cd7caf6 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -12,7 +12,6 @@ using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Request.Organizations; using Bit.Api.Models.Response; using Bit.Core; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; @@ -70,6 +69,7 @@ public class OrganizationsController : Controller private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPricingClient _pricingClient; private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand; + private readonly IOrganizationUpdateCommand _organizationUpdateCommand; public OrganizationsController( IOrganizationRepository organizationRepository, @@ -94,7 +94,8 @@ public class OrganizationsController : Controller IOrganizationDeleteCommand organizationDeleteCommand, IPolicyRequirementQuery policyRequirementQuery, IPricingClient pricingClient, - IOrganizationUpdateKeysCommand organizationUpdateKeysCommand) + IOrganizationUpdateKeysCommand organizationUpdateKeysCommand, + IOrganizationUpdateCommand organizationUpdateCommand) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -119,6 +120,7 @@ public class OrganizationsController : Controller _policyRequirementQuery = policyRequirementQuery; _pricingClient = pricingClient; _organizationUpdateKeysCommand = organizationUpdateKeysCommand; + _organizationUpdateCommand = organizationUpdateCommand; } [HttpGet("{id}")] @@ -224,36 +226,31 @@ public class OrganizationsController : Controller return new OrganizationResponseModel(result.Organization, plan); } - [HttpPut("{id}")] - public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) + [HttpPut("{organizationId:guid}")] + public async Task Put(Guid organizationId, [FromBody] OrganizationUpdateRequestModel model) { - var orgIdGuid = new Guid(id); + // If billing email is being changed, require subscription editing permissions. + // Otherwise, organization owner permissions are sufficient. + var requiresBillingPermission = model.BillingEmail is not null; + var authorized = requiresBillingPermission + ? await _currentContext.EditSubscription(organizationId) + : await _currentContext.OrganizationOwner(organizationId); - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) + if (!authorized) { - throw new NotFoundException(); + return TypedResults.Unauthorized(); } - var updateBilling = ShouldUpdateBilling(model, organization); + var commandRequest = model.ToCommandRequest(organizationId); + var updatedOrganization = await _organizationUpdateCommand.UpdateAsync(commandRequest); - var hasRequiredPermissions = updateBilling - ? await _currentContext.EditSubscription(orgIdGuid) - : await _currentContext.OrganizationOwner(orgIdGuid); - - if (!hasRequiredPermissions) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); - var plan = await _pricingClient.GetPlan(organization.PlanType); - return new OrganizationResponseModel(organization, plan); + var plan = await _pricingClient.GetPlan(updatedOrganization.PlanType); + return TypedResults.Ok(new OrganizationResponseModel(updatedOrganization, plan)); } [HttpPost("{id}")] [Obsolete("This endpoint is deprecated. Use PUT method instead")] - public async Task PostPut(string id, [FromBody] OrganizationUpdateRequestModel model) + public async Task PostPut(Guid id, [FromBody] OrganizationUpdateRequestModel model) { return await Put(id, model); } @@ -588,11 +585,4 @@ public class OrganizationsController : Controller return organization.PlanType; } - - private bool ShouldUpdateBilling(OrganizationUpdateRequestModel model, Organization organization) - { - var organizationNameChanged = model.Name != organization.Name; - var billingEmailChanged = model.BillingEmail != organization.BillingEmail; - return !_globalSettings.SelfHosted && (organizationNameChanged || billingEmailChanged); - } } diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index a5272413e2..ae1d12e887 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -42,7 +42,6 @@ public class PoliciesController : Controller private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IPolicyRepository _policyRepository; private readonly IUserService _userService; - private readonly IFeatureService _featureService; private readonly ISavePolicyCommand _savePolicyCommand; private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; @@ -55,7 +54,6 @@ public class PoliciesController : Controller IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IOrganizationRepository organizationRepository, - IFeatureService featureService, ISavePolicyCommand savePolicyCommand, IVNextSavePolicyCommand vNextSavePolicyCommand) { @@ -69,7 +67,6 @@ public class PoliciesController : Controller _organizationRepository = organizationRepository; _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; - _featureService = featureService; _savePolicyCommand = savePolicyCommand; _vNextSavePolicyCommand = vNextSavePolicyCommand; } @@ -221,9 +218,7 @@ public class PoliciesController : Controller { var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, type, _currentContext); - var policy = _featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) ? - await _vNextSavePolicyCommand.SaveAsync(savePolicyRequest) : - await _savePolicyCommand.VNextSaveAsync(savePolicyRequest); + var policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyRequest); return new PolicyResponseModel(policy); } diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index aa87bf9c74..515404e8a9 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -5,6 +5,7 @@ using Bit.Api.AdminConsole.Models.Request.Providers; using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Services; @@ -23,15 +24,20 @@ public class ProvidersController : Controller private readonly IProviderService _providerService; private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; + private readonly IProviderBillingService _providerBillingService; + private readonly ILogger _logger; public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, + IProviderBillingService providerBillingService, ILogger logger) { _userService = userService; _providerRepository = providerRepository; _providerService = providerService; _currentContext = currentContext; _globalSettings = globalSettings; + _providerBillingService = providerBillingService; + _logger = logger; } [HttpGet("{id:guid}")] @@ -65,7 +71,27 @@ public class ProvidersController : Controller throw new NotFoundException(); } + // Capture original values before modifications for Stripe sync + var originalName = provider.Name; + var originalBillingEmail = provider.BillingEmail; + await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); + + // Sync name/email changes to Stripe + if (originalName != provider.Name || originalBillingEmail != provider.BillingEmail) + { + try + { + await _providerBillingService.UpdateProviderNameAndEmail(provider); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.", + provider.Id); + } + } + return new ProviderResponseModel(provider); } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs index 8581c4ae1f..9341392d68 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs @@ -1,6 +1,4 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; @@ -16,38 +14,6 @@ public class OrganizationIntegrationConfigurationRequestModel public string? Template { get; set; } - public bool IsValidForType(IntegrationType integrationType) - { - switch (integrationType) - { - case IntegrationType.CloudBillingSync or IntegrationType.Scim: - return false; - case IntegrationType.Slack: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Webhook: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Hec: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - case IntegrationType.Datadog: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - case IntegrationType.Teams: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - default: - return false; - - } - } - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(Guid organizationIntegrationId) { return new OrganizationIntegrationConfiguration() @@ -59,50 +25,4 @@ public class OrganizationIntegrationConfigurationRequestModel Template = Template }; } - - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(OrganizationIntegrationConfiguration currentConfiguration) - { - currentConfiguration.Configuration = Configuration; - currentConfiguration.EventType = EventType; - currentConfiguration.Filters = Filters; - currentConfiguration.Template = Template; - - return currentConfiguration; - } - - private bool IsConfigurationValid() - { - if (string.IsNullOrWhiteSpace(Configuration)) - { - return false; - } - - try - { - var config = JsonSerializer.Deserialize(Configuration); - return config is not null; - } - catch - { - return false; - } - } - - private bool IsFiltersValid() - { - if (Filters is null) - { - return true; - } - - try - { - var filters = JsonSerializer.Deserialize(Filters); - return filters is not null; - } - catch - { - return false; - } - } } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs index 5a3192c121..6c3867fe09 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs @@ -1,41 +1,28 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Models.Data; -using Bit.Core.Settings; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request.Organizations; public class OrganizationUpdateRequestModel { - [Required] [StringLength(50, ErrorMessage = "The field Name exceeds the maximum length.")] [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } - [StringLength(50, ErrorMessage = "The field Business Name exceeds the maximum length.")] - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string BusinessName { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } - public Permissions Permissions { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } + public string? Name { get; set; } - public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) + [EmailAddress] + [StringLength(256)] + public string? BillingEmail { get; set; } + + public OrganizationKeysRequestModel? Keys { get; set; } + + public OrganizationUpdateRequest ToCommandRequest(Guid organizationId) => new() { - if (!globalSettings.SelfHosted) - { - // These items come from the license file - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - } - Keys?.ToOrganization(existingOrganization); - return existingOrganization; - } + OrganizationId = organizationId, + Name = Name, + BillingEmail = BillingEmail, + PublicKey = Keys?.PublicKey, + EncryptedPrivateKey = Keys?.EncryptedPrivateKey + }; } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs index 4e0accb9e8..b7a4db3acd 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -119,7 +119,7 @@ public class OrganizationUserResetPasswordEnrollmentRequestModel public class OrganizationUserBulkRequestModel { - [Required] + [Required, MinLength(1)] public IEnumerable Ids { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs index c172c45e94..f5ef468b4e 100644 --- a/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs @@ -47,6 +47,7 @@ public abstract class BaseProfileOrganizationResponseModel : ResponseModel UseAdminSponsoredFamilies = organizationDetails.UseAdminSponsoredFamilies; UseAutomaticUserConfirmation = organizationDetails.UseAutomaticUserConfirmation; UseSecretsManager = organizationDetails.UseSecretsManager; + UsePhishingBlocker = organizationDetails.UsePhishingBlocker; UsePasswordManager = organizationDetails.UsePasswordManager; SelfHost = organizationDetails.SelfHost; Seats = organizationDetails.Seats; @@ -99,6 +100,7 @@ public abstract class BaseProfileOrganizationResponseModel : ResponseModel public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } public bool SelfHost { get; set; } public int? Seats { get; set; } public short? MaxCollections { get; set; } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs index 8006a85734..9a3543f4bb 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs @@ -1,10 +1,13 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using System.Security.Claims; using System.Text.Json.Serialization; using Bit.Api.Models.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Models.Api; using Bit.Core.Models.Business; @@ -71,6 +74,7 @@ public class OrganizationResponseModel : ResponseModel UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; + UsePhishingBlocker = organization.UsePhishingBlocker; } public Guid Id { get; set; } @@ -120,6 +124,7 @@ public class OrganizationResponseModel : ResponseModel public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } public class OrganizationSubscriptionResponseModel : OrganizationResponseModel @@ -175,6 +180,30 @@ public class OrganizationSubscriptionResponseModel : OrganizationResponseModel } } + public OrganizationSubscriptionResponseModel(Organization organization, OrganizationLicense license, ClaimsPrincipal claimsPrincipal) : + this(organization, (Plan)null) + { + if (license != null) + { + // CRITICAL: When a license has a Token (JWT), ALWAYS use the expiration from the token claim + // The token's expiration is cryptographically secured and cannot be tampered with + // The file's Expires property can be manually edited and should NOT be trusted for display + if (claimsPrincipal != null) + { + Expiration = claimsPrincipal.GetValue(OrganizationLicenseConstants.Expires); + ExpirationWithoutGracePeriod = claimsPrincipal.GetValue(OrganizationLicenseConstants.ExpirationWithoutGracePeriod); + } + else + { + // No token - use the license file expiration (for older licenses without tokens) + Expiration = license.Expires; + ExpirationWithoutGracePeriod = license.ExpirationWithoutGracePeriod ?? (license.Trial + ? license.Expires + : license.Expires?.AddDays(-Constants.OrganizationSelfHostSubscriptionGracePeriodDays)); + } + } + } + public string StorageName { get; set; } public double? StorageGb { get; set; } public BillingCustomerDiscount CustomerDiscount { get; set; } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs index 81ca801308..0507de7a55 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs @@ -30,6 +30,7 @@ public class PolicyResponseModel : ResponseModel { Data = JsonSerializer.Deserialize>(policy.Data); } + RevisionDate = policy.RevisionDate; } public Guid Id { get; set; } @@ -37,4 +38,5 @@ public class PolicyResponseModel : ResponseModel public PolicyType Type { get; set; } public Dictionary Data { get; set; } public bool Enabled { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs index 97a58d038a..8c52092dae 100644 --- a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs @@ -1,4 +1,5 @@ -using Bit.Core.Enums; +using Bit.Core.Billing.Models; +using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; @@ -27,7 +28,7 @@ public class ProfileOrganizationResponseModel : BaseProfileOrganizationResponseM FamilySponsorshipToDelete = organizationDetails.FamilySponsorshipToDelete; FamilySponsorshipValidUntil = organizationDetails.FamilySponsorshipValidUntil; FamilySponsorshipAvailable = (organizationDetails.FamilySponsorshipFriendlyName == null || IsAdminInitiated) && - StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) + SponsoredPlans.Get(PlanSponsorshipType.FamiliesForEnterprise) .UsersCanSponsor(organizationDetails); AccessSecretsManager = organizationDetails.AccessSecretsManager; } diff --git a/src/Api/AdminConsole/Public/Controllers/EventsController.cs b/src/Api/AdminConsole/Public/Controllers/EventsController.cs index 19edbdd5a6..b92e576ef9 100644 --- a/src/Api/AdminConsole/Public/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Public/Controllers/EventsController.cs @@ -1,6 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - + using System.Net; using Bit.Api.Models.Public.Request; using Bit.Api.Models.Public.Response; @@ -8,6 +6,7 @@ using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.Context; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Vault.Repositories; using Microsoft.AspNetCore.Authorization; @@ -22,6 +21,9 @@ public class EventsController : Controller private readonly IEventRepository _eventRepository; private readonly ICipherRepository _cipherRepository; private readonly ICurrentContext _currentContext; + private readonly ISecretRepository _secretRepository; + private readonly IProjectRepository _projectRepository; + private readonly IUserService _userService; private readonly ILogger _logger; private readonly IFeatureService _featureService; @@ -29,12 +31,18 @@ public class EventsController : Controller IEventRepository eventRepository, ICipherRepository cipherRepository, ICurrentContext currentContext, + ISecretRepository secretRepository, + IProjectRepository projectRepository, + IUserService userService, ILogger logger, IFeatureService featureService) { _eventRepository = eventRepository; _cipherRepository = cipherRepository; _currentContext = currentContext; + _secretRepository = secretRepository; + _projectRepository = projectRepository; + _userService = userService; _logger = logger; _featureService = featureService; } @@ -50,35 +58,76 @@ public class EventsController : Controller [ProducesResponseType(typeof(PagedListResponseModel), (int)HttpStatusCode.OK)] public async Task List([FromQuery] EventFilterRequestModel request) { + if (!_currentContext.OrganizationId.HasValue) + { + return new JsonResult(new PagedListResponseModel([], "")); + } + + var organizationId = _currentContext.OrganizationId.Value; var dateRange = request.ToDateRange(); var result = new PagedResult(); if (request.ActingUserId.HasValue) { result = await _eventRepository.GetManyByOrganizationActingUserAsync( - _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, + organizationId, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = request.ContinuationToken }); } else if (request.ItemId.HasValue) { var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); - if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) + if (cipher != null && cipher.OrganizationId == organizationId) { result = await _eventRepository.GetManyByCipherAsync( cipher, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = request.ContinuationToken }); } } + else if (request.SecretId.HasValue) + { + var secret = await _secretRepository.GetByIdAsync(request.SecretId.Value); + + if (secret == null) + { + secret = new Core.SecretsManager.Entities.Secret { Id = request.SecretId.Value, OrganizationId = organizationId }; + } + + if (secret.OrganizationId == organizationId) + { + result = await _eventRepository.GetManyBySecretAsync( + secret, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else + { + return new JsonResult(new PagedListResponseModel([], "")); + } + } + else if (request.ProjectId.HasValue) + { + var project = await _projectRepository.GetByIdAsync(request.ProjectId.Value); + if (project != null && project.OrganizationId == organizationId) + { + result = await _eventRepository.GetManyByProjectAsync( + project, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else + { + return new JsonResult(new PagedListResponseModel([], "")); + } + } else { result = await _eventRepository.GetManyByOrganizationAsync( - _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, + organizationId, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = request.ContinuationToken }); } var eventResponses = result.Data.Select(e => new EventResponseModel(e)); - var response = new PagedListResponseModel(eventResponses, result.ContinuationToken); + var response = new PagedListResponseModel(eventResponses, result.ContinuationToken ?? ""); + + _logger.LogAggregateData(_featureService, organizationId, response, request); - _logger.LogAggregateData(_featureService, _currentContext.OrganizationId!.Value, response, request); return new JsonResult(response); } } diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 3b2e82121d..58e5db18c2 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; using Bit.Core.Services; @@ -24,7 +25,7 @@ public class MembersController : Controller private readonly ICurrentContext _currentContext; private readonly IUpdateOrganizationUserCommand _updateOrganizationUserCommand; private readonly IUpdateOrganizationUserGroupsCommand _updateOrganizationUserGroupsCommand; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationRepository _organizationRepository; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; @@ -37,7 +38,7 @@ public class MembersController : Controller ICurrentContext currentContext, IUpdateOrganizationUserCommand updateOrganizationUserCommand, IUpdateOrganizationUserGroupsCommand updateOrganizationUserGroupsCommand, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IRemoveOrganizationUserCommand removeOrganizationUserCommand, diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs index be0997f271..cf8da813be 100644 --- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs @@ -5,15 +5,10 @@ using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; -using Bit.Core; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Context; -using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -24,25 +19,16 @@ namespace Bit.Api.AdminConsole.Public.Controllers; public class PoliciesController : Controller { private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; private readonly ICurrentContext _currentContext; - private readonly IFeatureService _featureService; - private readonly ISavePolicyCommand _savePolicyCommand; private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public PoliciesController( IPolicyRepository policyRepository, - IPolicyService policyService, ICurrentContext currentContext, - IFeatureService featureService, - ISavePolicyCommand savePolicyCommand, IVNextSavePolicyCommand vNextSavePolicyCommand) { _policyRepository = policyRepository; - _policyService = policyService; _currentContext = currentContext; - _featureService = featureService; - _savePolicyCommand = savePolicyCommand; _vNextSavePolicyCommand = vNextSavePolicyCommand; } @@ -97,17 +83,8 @@ public class PoliciesController : Controller [ProducesResponseType((int)HttpStatusCode.NotFound)] public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) { - Policy policy; - if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) - { - var savePolicyModel = model.ToSavePolicyModel(_currentContext.OrganizationId!.Value, type); - policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyModel); - } - else - { - var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type); - policy = await _savePolicyCommand.SaveAsync(policyUpdate); - } + var savePolicyModel = model.ToSavePolicyModel(_currentContext.OrganizationId!.Value, type); + var policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyModel); var response = new PolicyResponseModel(policy); return new JsonResult(response); diff --git a/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs b/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs index 2d96425d55..a007349f26 100644 --- a/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs +++ b/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs @@ -24,6 +24,14 @@ public class EventFilterRequestModel /// public Guid? ItemId { get; set; } /// + /// The unique identifier of the related secret that the event describes. + /// + public Guid? SecretId { get; set; } + /// + /// The unique identifier of the related project that the event describes. + /// + public Guid? ProjectId { get; set; } + /// /// A cursor for use in pagination. /// public string ContinuationToken { get; set; } diff --git a/src/Api/Api.csproj b/src/Api/Api.csproj index 138549e92d..48fedfc8c1 100644 --- a/src/Api/Api.csproj +++ b/src/Api/Api.csproj @@ -33,7 +33,7 @@ - + diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index ecf49c18c8..839d00f7a1 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -18,6 +18,7 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Api.Response; using Bit.Core.Repositories; @@ -44,6 +45,7 @@ public class AccountsController : Controller private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; + private readonly IUserRepository _userRepository; public AccountsController( IOrganizationService organizationService, @@ -57,7 +59,8 @@ public class AccountsController : Controller IFeatureService featureService, IUserAccountKeysQuery userAccountKeysQuery, ITwoFactorEmailService twoFactorEmailService, - IChangeKdfCommand changeKdfCommand + IChangeKdfCommand changeKdfCommand, + IUserRepository userRepository ) { _organizationService = organizationService; @@ -72,6 +75,7 @@ public class AccountsController : Controller _userAccountKeysQuery = userAccountKeysQuery; _twoFactorEmailService = twoFactorEmailService; _changeKdfCommand = changeKdfCommand; + _userRepository = userRepository; } @@ -432,16 +436,36 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - if (_featureService.IsEnabled(FeatureFlagKeys.ReturnErrorOnExistingKeypair)) + if (!string.IsNullOrWhiteSpace(user.PrivateKey) || !string.IsNullOrWhiteSpace(user.PublicKey)) { - if (!string.IsNullOrWhiteSpace(user.PrivateKey) || !string.IsNullOrWhiteSpace(user.PublicKey)) - { - throw new BadRequestException("User has existing keypair"); - } + throw new BadRequestException("User has existing keypair"); + } + + if (model.AccountKeys != null) + { + var accountKeysData = model.AccountKeys.ToAccountKeysData(); + if (!accountKeysData.IsV2Encryption()) + { + throw new BadRequestException("AccountKeys are only supported for V2 encryption."); + } + await _userRepository.SetV2AccountCryptographicStateAsync(user.Id, accountKeysData); + return new KeysResponseModel(accountKeysData, user.Key); + } + else + { + // Todo: Drop this after a transition period. This will drop no-account-keys requests. + // The V1 check in the other branch should persist + // https://bitwarden.atlassian.net/browse/PM-27329 + await _userService.SaveUserAsync(model.ToUser(user)); + return new KeysResponseModel(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( + user.PrivateKey, + user.PublicKey + ) + }, user.Key); } - await _userService.SaveUserAsync(model.ToUser(user)); - return new KeysResponseModel(user); } [HttpGet("keys")] @@ -453,7 +477,8 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - return new KeysResponseModel(user); + var accountKeys = await _userAccountKeysQuery.Run(user); + return new KeysResponseModel(accountKeys, user.Key); } [HttpDelete] diff --git a/src/Api/Auth/Controllers/TwoFactorController.cs b/src/Api/Auth/Controllers/TwoFactorController.cs index 0af46fb57c..ba6cf66859 100644 --- a/src/Api/Auth/Controllers/TwoFactorController.cs +++ b/src/Api/Auth/Controllers/TwoFactorController.cs @@ -9,7 +9,6 @@ using Bit.Api.Models.Response; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Identity.TokenProviders; -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Services; using Bit.Core.Context; @@ -35,7 +34,7 @@ public class TwoFactorController : Controller private readonly IOrganizationService _organizationService; private readonly UserManager _userManager; private readonly ICurrentContext _currentContext; - private readonly IVerifyAuthRequestCommand _verifyAuthRequestCommand; + private readonly IAuthRequestRepository _authRequestRepository; private readonly IDuoUniversalTokenService _duoUniversalTokenService; private readonly IDataProtectorTokenFactory _twoFactorAuthenticatorDataProtector; private readonly IDataProtectorTokenFactory _ssoEmailTwoFactorSessionDataProtector; @@ -47,7 +46,7 @@ public class TwoFactorController : Controller IOrganizationService organizationService, UserManager userManager, ICurrentContext currentContext, - IVerifyAuthRequestCommand verifyAuthRequestCommand, + IAuthRequestRepository authRequestRepository, IDuoUniversalTokenService duoUniversalConfigService, IDataProtectorTokenFactory twoFactorAuthenticatorDataProtector, IDataProtectorTokenFactory ssoEmailTwoFactorSessionDataProtector, @@ -58,7 +57,7 @@ public class TwoFactorController : Controller _organizationService = organizationService; _userManager = userManager; _currentContext = currentContext; - _verifyAuthRequestCommand = verifyAuthRequestCommand; + _authRequestRepository = authRequestRepository; _duoUniversalTokenService = duoUniversalConfigService; _twoFactorAuthenticatorDataProtector = twoFactorAuthenticatorDataProtector; _ssoEmailTwoFactorSessionDataProtector = ssoEmailTwoFactorSessionDataProtector; @@ -350,14 +349,15 @@ public class TwoFactorController : Controller if (user != null) { - // Check if 2FA email is from Passwordless. + // Check if 2FA email is from a device approval ("Log in with device") scenario. if (!string.IsNullOrEmpty(requestModel.AuthRequestAccessCode)) { - if (await _verifyAuthRequestCommand - .VerifyAuthRequestAsync(new Guid(requestModel.AuthRequestId), - requestModel.AuthRequestAccessCode)) + var authRequest = await _authRequestRepository.GetByIdAsync(new Guid(requestModel.AuthRequestId)); + if (authRequest != null && + authRequest.IsValidForAuthentication(user.Id, requestModel.AuthRequestAccessCode)) { await _twoFactorEmailService.SendTwoFactorEmailAsync(user); + return; } } else if (!string.IsNullOrEmpty(requestModel.SsoEmail2FaSessionToken)) diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index 7abcf8c357..243f4d3c53 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -1,7 +1,5 @@ -#nullable enable -using Bit.Api.Billing.Models.Responses; +using Bit.Api.Billing.Models.Responses; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Requests; using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -12,10 +10,11 @@ namespace Bit.Api.Billing.Controllers; [Route("accounts/billing")] [Authorize("Application")] public class AccountsBillingController( - IPaymentService paymentService, + IStripePaymentService paymentService, IUserService userService, IPaymentHistoryService paymentHistoryService) : Controller { + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("history")] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingHistoryAsync() @@ -30,20 +29,7 @@ public class AccountsBillingController( return new BillingHistoryResponseModel(billingInfo); } - [HttpGet("payment-method")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetPaymentMethodAsync() - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var billingInfo = await paymentService.GetBillingAsync(user); - return new BillingPaymentResponseModel(billingInfo); - } - + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null) { @@ -62,6 +48,7 @@ public class AccountsBillingController( return TypedResults.Ok(invoices); } + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("transactions")] public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter = null) { @@ -78,18 +65,4 @@ public class AccountsBillingController( return TypedResults.Ok(transactions); } - - [HttpPost("preview-invoice")] - public async Task PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, user.GatewayCustomerId, user.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } } diff --git a/src/Api/Billing/Controllers/AccountsController.cs b/src/Api/Billing/Controllers/AccountsController.cs index 075218dd74..5d3e095fdd 100644 --- a/src/Api/Billing/Controllers/AccountsController.cs +++ b/src/Api/Billing/Controllers/AccountsController.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Api.Models.Request; +using Bit.Api.Models.Request; using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Response; using Bit.Api.Utilities; @@ -26,8 +24,10 @@ public class AccountsController( IUserService userService, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IUserAccountKeysQuery userAccountKeysQuery, - IFeatureService featureService) : Controller + IFeatureService featureService, + ILicensingService licensingService) : Controller { + // TODO: Remove when pm-24996-implement-upgrade-from-free-dialog is removed [HttpPost("premium")] public async Task PostPremiumAsync( PremiumRequestModel model, @@ -75,10 +75,11 @@ public class AccountsController( }; } + // TODO: Migrate to Query / AccountBillingVNextController as part of Premium -> Organization upgrade work. [HttpGet("subscription")] public async Task GetSubscriptionAsync( [FromServices] GlobalSettings globalSettings, - [FromServices] IPaymentService paymentService) + [FromServices] IStripePaymentService paymentService) { var user = await userService.GetUserByPrincipalAsync(User); if (user == null) @@ -97,12 +98,14 @@ public class AccountsController( var includeMilestone2Discount = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); var subscriptionInfo = await paymentService.GetSubscriptionAsync(user); var license = await userService.GenerateLicenseAsync(user, subscriptionInfo); - return new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license); + return new SubscriptionResponseModel(user, subscriptionInfo, license, claimsPrincipal, includeMilestone2Discount); } else { var license = await userService.GenerateLicenseAsync(user); - return new SubscriptionResponseModel(user, license); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license); + return new SubscriptionResponseModel(user, null, license, claimsPrincipal); } } else @@ -111,29 +114,7 @@ public class AccountsController( } } - [HttpPost("payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPaymentAsync([FromBody] PaymentRequestModel model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType!.Value, - new TaxInfo - { - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressCountry = model.Country, - BillingAddressPostalCode = model.PostalCode, - TaxIdNumber = model.TaxId - }); - } - + // TODO: Migrate to Command / AccountBillingVNextController as PUT /account/billing/vnext/subscription [HttpPost("storage")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostStorageAsync([FromBody] StorageRequestModel model) @@ -148,8 +129,11 @@ public class AccountsController( return new PaymentResponseModel { Success = true, PaymentIntentClientSecret = result }; } - - + /* + * TODO: A new version of this exists in the AccountBillingVNextController. + * The individual-self-hosting-license-uploader.component needs to be updated to use it. + * Then, this can be removed. + */ [HttpPost("license")] [SelfHosted(SelfHostedOnly = true)] public async Task PostLicenseAsync(LicenseRequestModel model) @@ -169,6 +153,7 @@ public class AccountsController( await userService.UpdateLicenseAsync(user, license); } + // TODO: Migrate to Command / AccountBillingVNextController as DELETE /account/billing/vnext/subscription [HttpPost("cancel")] public async Task PostCancelAsync( [FromBody] SubscriptionCancellationRequestModel request, @@ -186,6 +171,7 @@ public class AccountsController( user.IsExpired()); } + // TODO: Migrate to Command / AccountBillingVNextController as POST /account/billing/vnext/subscription/reinstate [HttpPost("reinstate-premium")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostReinstateAsync() @@ -199,41 +185,6 @@ public class AccountsController( await userService.ReinstatePremiumAsync(user); } - [HttpGet("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfoAsync( - [FromServices] IPaymentService paymentService) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = await paymentService.GetTaxInfoAsync(user); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfoAsync( - [FromBody] TaxInfoUpdateRequestModel model, - [FromServices] IPaymentService paymentService) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = new TaxInfo - { - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await paymentService.SaveTaxInfoAsync(user, taxInfo); - } - private async Task> GetOrganizationIdsClaimingUserAsync(Guid userId) { var organizationsClaimingUser = await userService.GetOrganizationsClaimingUserAsync(userId); diff --git a/src/Api/Billing/Controllers/InvoicesController.cs b/src/Api/Billing/Controllers/InvoicesController.cs deleted file mode 100644 index 30ea975e09..0000000000 --- a/src/Api/Billing/Controllers/InvoicesController.cs +++ /dev/null @@ -1,45 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("invoices")] -[Authorize("Application")] -public class InvoicesController : BaseBillingController -{ - [HttpPost("preview-organization")] - public async Task PreviewInvoiceAsync( - [FromBody] PreviewOrganizationInvoiceRequestBody model, - [FromServices] ICurrentContext currentContext, - [FromServices] IOrganizationRepository organizationRepository, - [FromServices] IPaymentService paymentService) - { - Organization organization = null; - if (model.OrganizationId != default) - { - if (!await currentContext.EditPaymentMethods(model.OrganizationId)) - { - return Error.Unauthorized(); - } - - organization = await organizationRepository.GetByIdAsync(model.OrganizationId); - if (organization == null) - { - return Error.NotFound(); - } - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, organization?.GatewayCustomerId, - organization?.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } -} diff --git a/src/Api/Billing/Controllers/LicensesController.cs b/src/Api/Billing/Controllers/LicensesController.cs deleted file mode 100644 index 29313bd4d8..0000000000 --- a/src/Api/Billing/Controllers/LicensesController.cs +++ /dev/null @@ -1,91 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationConnections.Interfaces; -using Bit.Core.Billing.Models.Business; -using Bit.Core.Billing.Organizations.Models; -using Bit.Core.Billing.Organizations.Queries; -using Bit.Core.Context; -using Bit.Core.Exceptions; -using Bit.Core.Models.Api.OrganizationLicenses; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("licenses")] -[Authorize("Licensing")] -[SelfHosted(NotSelfHostedOnly = true)] -public class LicensesController : Controller -{ - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IGetCloudOrganizationLicenseQuery _getCloudOrganizationLicenseQuery; - private readonly IValidateBillingSyncKeyCommand _validateBillingSyncKeyCommand; - private readonly ICurrentContext _currentContext; - - public LicensesController( - IUserRepository userRepository, - IUserService userService, - IOrganizationRepository organizationRepository, - IGetCloudOrganizationLicenseQuery getCloudOrganizationLicenseQuery, - IValidateBillingSyncKeyCommand validateBillingSyncKeyCommand, - ICurrentContext currentContext) - { - _userRepository = userRepository; - _userService = userService; - _organizationRepository = organizationRepository; - _getCloudOrganizationLicenseQuery = getCloudOrganizationLicenseQuery; - _validateBillingSyncKeyCommand = validateBillingSyncKeyCommand; - _currentContext = currentContext; - } - - [HttpGet("user/{id}")] - public async Task GetUser(string id, [FromQuery] string key) - { - var user = await _userRepository.GetByIdAsync(new Guid(id)); - if (user == null) - { - return null; - } - else if (!user.LicenseKey.Equals(key)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); - } - - var license = await _userService.GenerateLicenseAsync(user, null); - return license; - } - - /// - /// Used by self-hosted installations to get an updated license file - /// - [HttpGet("organization/{id}")] - public async Task OrganizationSync(string id, [FromBody] SelfHostedOrganizationLicenseRequestModel model) - { - var organization = await _organizationRepository.GetByIdAsync(new Guid(id)); - if (organization == null) - { - throw new NotFoundException("Organization not found."); - } - - if (!organization.LicenseKey.Equals(model.LicenseKey)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); - } - - if (!await _validateBillingSyncKeyCommand.ValidateBillingSyncKeyAsync(organization, model.BillingSyncKey)) - { - throw new BadRequestException("Invalid Billing Sync Key"); - } - - var license = await _getCloudOrganizationLicenseQuery.GetLicenseAsync(organization, _currentContext.InstallationId.Value); - return license; - } -} diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 6e4cacc155..e06d946ea0 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -5,7 +5,6 @@ using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -19,10 +18,10 @@ public class OrganizationBillingController( ICurrentContext currentContext, IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, - ISubscriberService subscriberService, + IStripePaymentService paymentService, IPaymentHistoryService paymentHistoryService) : BaseBillingController { + // TODO: Remove when pm-25379-use-new-organization-metadata-structure is removed. [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) { @@ -41,6 +40,7 @@ public class OrganizationBillingController( return TypedResults.Ok(metadata); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("history")] public async Task GetHistoryAsync([FromRoute] Guid organizationId) { @@ -61,6 +61,7 @@ public class OrganizationBillingController( return TypedResults.Ok(billingInfo); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid organizationId, [FromQuery] string? status = null, [FromQuery] string? startAfter = null) { @@ -85,6 +86,7 @@ public class OrganizationBillingController( return TypedResults.Ok(invoices); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("transactions")] public async Task GetTransactionsAsync([FromRoute] Guid organizationId, [FromQuery] DateTime? startAfter = null) { @@ -108,6 +110,7 @@ public class OrganizationBillingController( return TypedResults.Ok(transactions); } + // TODO: Can be removed once we do away with the organization-plans.component. [HttpGet] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingAsync(Guid organizationId) @@ -131,127 +134,7 @@ public class OrganizationBillingController( return TypedResults.Ok(response); } - [HttpGet("payment-method")] - public async Task GetPaymentMethodAsync([FromRoute] Guid organizationId) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var paymentMethod = await subscriberService.GetPaymentMethod(organization); - - var response = PaymentMethodResponse.From(paymentMethod); - - return TypedResults.Ok(response); - } - - [HttpPut("payment-method")] - public async Task UpdatePaymentMethodAsync( - [FromRoute] Guid organizationId, - [FromBody] UpdatePaymentMethodRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain(); - - var taxInformation = requestBody.TaxInformation.ToDomain(); - - await organizationBillingService.UpdatePaymentMethod(organization, tokenizedPaymentSource, taxInformation); - - return TypedResults.Ok(); - } - - [HttpPost("payment-method/verify-bank-account")] - public async Task VerifyBankAccountAsync( - [FromRoute] Guid organizationId, - [FromBody] VerifyBankAccountRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM")) - { - return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'"); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - await subscriberService.VerifyBankAccount(organization, requestBody.DescriptorCode); - - return TypedResults.Ok(); - } - - [HttpGet("tax-information")] - public async Task GetTaxInformationAsync([FromRoute] Guid organizationId) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var taxInformation = await subscriberService.GetTaxInformation(organization); - - var response = TaxInformationResponse.From(taxInformation); - - return TypedResults.Ok(response); - } - - [HttpPut("tax-information")] - public async Task UpdateTaxInformationAsync( - [FromRoute] Guid organizationId, - [FromBody] TaxInformationRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var taxInformation = requestBody.ToDomain(); - - await subscriberService.UpdateTaxInformation(organization, taxInformation); - - return TypedResults.Ok(); - } - + // TODO: Migrate to Command / OrganizationBillingVNextController [HttpPost("setup-business-unit")] [SelfHosted(NotSelfHostedOnly = true)] public async Task SetupBusinessUnitAsync( @@ -280,6 +163,7 @@ public class OrganizationBillingController( return TypedResults.Ok(providerId); } + // TODO: Migrate to Command / OrganizationBillingVNextController [HttpPost("change-frequency")] [SelfHosted(NotSelfHostedOnly = true)] public async Task ChangePlanSubscriptionFrequencyAsync( diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index 5494c5a90e..bca5605a8c 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -19,7 +19,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; using Bit.Core.Services; @@ -37,7 +36,7 @@ public class OrganizationsController( IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, IUserService userService, - IPaymentService paymentService, + IStripePaymentService paymentService, ICurrentContext currentContext, IGetCloudOrganizationLicenseQuery getCloudOrganizationLicenseQuery, GlobalSettings globalSettings, @@ -67,7 +66,8 @@ public class OrganizationsController( if (globalSettings.SelfHosted) { var orgLicense = await licensingService.ReadOrganizationLicenseAsync(organization); - return new OrganizationSubscriptionResponseModel(organization, orgLicense); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(orgLicense); + return new OrganizationSubscriptionResponseModel(organization, orgLicense, claimsPrincipal); } var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -248,53 +248,6 @@ public class OrganizationsController( await organizationService.ReinstateSubscriptionAsync(id); } - [HttpGet("{id:guid}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo(Guid id) - { - if (!await currentContext.OrganizationOwner(id)) - { - throw new NotFoundException(); - } - - var organization = await organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = await paymentService.GetTaxInfoAsync(organization); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("{id:guid}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo(Guid id, [FromBody] ExpandedTaxInfoUpdateRequestModel model) - { - if (!await currentContext.OrganizationOwner(id)) - { - throw new NotFoundException(); - } - - var organization = await organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = new TaxInfo - { - TaxIdNumber = model.TaxId, - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await paymentService.SaveTaxInfoAsync(organization, taxInfo); - } - /// /// Tries to grant owner access to the Secrets Manager for the organization /// diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index 006a7ce068..dfa705a329 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,7 +1,6 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Pricing; @@ -9,7 +8,6 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Models.BitStripe; using Bit.Core.Services; @@ -34,6 +32,7 @@ public class ProviderBillingController( IStripeAdapter stripeAdapter, IUserService userService) : BaseProviderController(currentContext, logger, providerRepository, userService) { + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid providerId) { @@ -44,7 +43,7 @@ public class ProviderBillingController( return result; } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = provider.GatewayCustomerId }); @@ -54,6 +53,7 @@ public class ProviderBillingController( return TypedResults.Ok(response); } + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("invoices/{invoiceId}")] public async Task GenerateClientInvoiceReportAsync([FromRoute] Guid providerId, string invoiceId) { @@ -76,51 +76,7 @@ public class ProviderBillingController( "text/csv"); } - [HttpPut("payment-method")] - public async Task UpdatePaymentMethodAsync( - [FromRoute] Guid providerId, - [FromBody] UpdatePaymentMethodRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain(); - var taxInformation = requestBody.TaxInformation.ToDomain(); - - await providerBillingService.UpdatePaymentMethod( - provider, - tokenizedPaymentSource, - taxInformation); - - return TypedResults.Ok(); - } - - [HttpPost("payment-method/verify-bank-account")] - public async Task VerifyBankAccountAsync( - [FromRoute] Guid providerId, - [FromBody] VerifyBankAccountRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM")) - { - return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'"); - } - - await subscriberService.VerifyBankAccount(provider, requestBody.DescriptorCode); - - return TypedResults.Ok(); - } - + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) { @@ -131,7 +87,7 @@ public class ProviderBillingController( return result; } - var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.tax_ids", "discounts", "test_clock"] }); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -140,7 +96,7 @@ public class ProviderBillingController( { var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, plan.Type); - var price = await stripeAdapter.PriceGetAsync(priceId); + var price = await stripeAdapter.GetPriceAsync(priceId); var unitAmount = price.UnitAmountDecimal.HasValue ? price.UnitAmountDecimal.Value / 100M @@ -172,53 +128,4 @@ public class ProviderBillingController( return TypedResults.Ok(response); } - - [HttpGet("tax-information")] - public async Task GetTaxInformationAsync([FromRoute] Guid providerId) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var taxInformation = await subscriberService.GetTaxInformation(provider); - - var response = TaxInformationResponse.From(taxInformation); - - return TypedResults.Ok(response); - } - - [HttpPut("tax-information")] - public async Task UpdateTaxInformationAsync( - [FromRoute] Guid providerId, - [FromBody] TaxInformationRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - if (requestBody is not { Country: not null, PostalCode: not null }) - { - return Error.BadRequest("Country and postal code are required to update your tax information."); - } - - var taxInformation = new TaxInformation( - requestBody.Country, - requestBody.PostalCode, - requestBody.TaxId, - requestBody.TaxIdType, - requestBody.Line1, - requestBody.Line2, - requestBody.City, - requestBody.State); - - await subscriberService.UpdateTaxInformation(provider, taxInformation); - - return TypedResults.Ok(); - } } diff --git a/src/Api/Billing/Controllers/StripeController.cs b/src/Api/Billing/Controllers/StripeController.cs index 15fccd16f4..6cb10e3165 100644 --- a/src/Api/Billing/Controllers/StripeController.cs +++ b/src/Api/Billing/Controllers/StripeController.cs @@ -1,5 +1,5 @@ -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -28,7 +28,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } @@ -43,7 +43,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs similarity index 92% rename from src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs rename to src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs index 973a7d99a1..b86f29bdbc 100644 --- a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs +++ b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Api.Utilities; using Bit.Core; @@ -17,7 +16,7 @@ namespace Bit.Api.Billing.Controllers.VNext; [Authorize("Application")] [Route("account/billing/vnext/self-host")] [SelfHosted(SelfHostedOnly = true)] -public class SelfHostedAccountBillingController( +public class SelfHostedAccountBillingVNextController( ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController { [HttpPost("license")] diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs new file mode 100644 index 0000000000..625a97c998 --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs @@ -0,0 +1,35 @@ +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.AdminConsole.Authorization.Requirements; +using Bit.Api.Billing.Attributes; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; + +namespace Bit.Api.Billing.Controllers.VNext; + +[Authorize("Application")] +[Route("organizations/{organizationId:guid}/billing/vnext/self-host")] +[SelfHosted(SelfHostedOnly = true)] +public class SelfHostedOrganizationBillingVNextController( + IGetOrganizationMetadataQuery getOrganizationMetadataQuery) : BaseBillingController +{ + [Authorize] + [HttpGet("metadata")] + [RequireFeature(FeatureFlagKeys.PM25379_UseNewOrganizationMetadataStructure)] + [InjectOrganization] + public async Task GetMetadataAsync([BindNever] Organization organization) + { + var metadata = await getOrganizationMetadataQuery.Run(organization); + + if (metadata == null) + { + return TypedResults.NotFound(); + } + + return TypedResults.Ok(metadata); + } +} diff --git a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs deleted file mode 100644 index a1b754a9dc..0000000000 --- a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs +++ /dev/null @@ -1,31 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Requests; - -public class TaxInformationRequestBody -{ - [Required] - public string Country { get; set; } - [Required] - public string PostalCode { get; set; } - public string TaxId { get; set; } - public string TaxIdType { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } - - public TaxInformation ToDomain() => new( - Country, - PostalCode, - TaxId, - TaxIdType, - Line1, - Line2, - City, - State); -} diff --git a/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs b/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs deleted file mode 100644 index b469ce2576..0000000000 --- a/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs +++ /dev/null @@ -1,25 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Api.Utilities; -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Api.Billing.Models.Requests; - -public class TokenizedPaymentSourceRequestBody -{ - [Required] - [EnumMatches( - PaymentMethodType.BankAccount, - PaymentMethodType.Card, - PaymentMethodType.PayPal, - ErrorMessage = "'type' must be BankAccount, Card or PayPal")] - public PaymentMethodType Type { get; set; } - - [Required] - public string Token { get; set; } - - public TokenizedPaymentSource ToDomain() => new(Type, Token); -} diff --git a/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs b/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs deleted file mode 100644 index 05ab1e34c9..0000000000 --- a/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs +++ /dev/null @@ -1,15 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Billing.Models.Requests; - -public class UpdatePaymentMethodRequestBody -{ - [Required] - public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } - - [Required] - public TaxInformationRequestBody TaxInformation { get; set; } -} diff --git a/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs b/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs deleted file mode 100644 index e248d55dde..0000000000 --- a/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs +++ /dev/null @@ -1,12 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Billing.Models.Requests; - -public class VerifyBankAccountRequestBody -{ - [Required] - public string DescriptorCode { get; set; } -} diff --git a/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs b/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs deleted file mode 100644 index f305e41c4f..0000000000 --- a/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs +++ /dev/null @@ -1,20 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Billing.Models; -using Bit.Core.Models.Api; - -namespace Bit.Api.Billing.Models.Responses; - -public class BillingPaymentResponseModel : ResponseModel -{ - public BillingPaymentResponseModel(BillingInfo billing) - : base("billingPayment") - { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - } - - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } -} diff --git a/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs b/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs deleted file mode 100644 index a54ac0a876..0000000000 --- a/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record PaymentMethodResponse( - decimal AccountCredit, - PaymentSource PaymentSource, - string SubscriptionStatus, - TaxInformation TaxInformation) -{ - public static PaymentMethodResponse From(PaymentMethod paymentMethod) => - new( - paymentMethod.AccountCredit, - paymentMethod.PaymentSource, - paymentMethod.SubscriptionStatus, - paymentMethod.TaxInformation); -} diff --git a/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs b/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs deleted file mode 100644 index 2c9a63b1d0..0000000000 --- a/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Api.Billing.Models.Responses; - -public record PaymentSourceResponse( - PaymentMethodType Type, - string Description, - bool NeedsVerification) -{ - public static PaymentSourceResponse From(PaymentSource paymentMethod) - => new( - paymentMethod.Type, - paymentMethod.Description, - paymentMethod.NeedsVerification); -} diff --git a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs b/src/Api/Billing/Models/Responses/TaxInformationResponse.cs deleted file mode 100644 index 59e4934751..0000000000 --- a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs +++ /dev/null @@ -1,23 +0,0 @@ -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record TaxInformationResponse( - string Country, - string PostalCode, - string TaxId, - string Line1, - string Line2, - string City, - string State) -{ - public static TaxInformationResponse From(TaxInformation taxInformation) - => new( - taxInformation.Country, - taxInformation.PostalCode, - taxInformation.TaxId, - taxInformation.Line1, - taxInformation.Line2, - taxInformation.City, - taxInformation.State); -} diff --git a/src/Api/Controllers/PhishingDomainsController.cs b/src/Api/Controllers/PhishingDomainsController.cs deleted file mode 100644 index f0c1a65648..0000000000 --- a/src/Api/Controllers/PhishingDomainsController.cs +++ /dev/null @@ -1,34 +0,0 @@ -using Bit.Core; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Controllers; - -[Route("phishing-domains")] -public class PhishingDomainsController(IPhishingDomainRepository phishingDomainRepository, IFeatureService featureService) : Controller -{ - [HttpGet] - public async Task>> GetPhishingDomainsAsync() - { - if (!featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - return NotFound(); - } - - var domains = await phishingDomainRepository.GetActivePhishingDomainsAsync(); - return Ok(domains); - } - - [HttpGet("checksum")] - public async Task> GetChecksumAsync() - { - if (!featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - return NotFound(); - } - - var checksum = await phishingDomainRepository.GetCurrentChecksumAsync(); - return Ok(checksum); - } -} diff --git a/src/Api/Dirt/Controllers/HibpController.cs b/src/Api/Dirt/Controllers/HibpController.cs index d108fdbd4f..8060384502 100644 --- a/src/Api/Dirt/Controllers/HibpController.cs +++ b/src/Api/Dirt/Controllers/HibpController.cs @@ -66,7 +66,10 @@ public class HibpController : Controller } else if (response.StatusCode == HttpStatusCode.NotFound) { - return new NotFoundResult(); + /* 12/1/2025 - Per the HIBP API, If the domain does not have any email addresses in any breaches, + an HTTP 404 response will be returned. API also specifies that "404 Not found is the account could + not be found and has therefore not been pwned". Per REST semantics we will return 200 OK with empty array. */ + return Content("[]", "application/json"); } else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) { diff --git a/src/Api/Jobs/JobsHostedService.cs b/src/Api/Jobs/JobsHostedService.cs index 0178f6d68b..a9626dc90e 100644 --- a/src/Api/Jobs/JobsHostedService.cs +++ b/src/Api/Jobs/JobsHostedService.cs @@ -59,13 +59,6 @@ public class JobsHostedService : BaseJobsHostedService .StartNow() .WithCronSchedule("0 0 * * * ?") .Build(); - var updatePhishingDomainsTrigger = TriggerBuilder.Create() - .WithIdentity("UpdatePhishingDomainsTrigger") - .StartNow() - .WithSimpleSchedule(x => x - .WithIntervalInHours(24) - .RepeatForever()) - .Build(); var updateOrgSubscriptionsTrigger = TriggerBuilder.Create() .WithIdentity("UpdateOrgSubscriptionsTrigger") .StartNow() @@ -81,7 +74,6 @@ public class JobsHostedService : BaseJobsHostedService new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger), new Tuple(typeof(ValidateOrganizationDomainJob), validateOrganizationDomainTrigger), - new Tuple(typeof(UpdatePhishingDomainsJob), updatePhishingDomainsTrigger), new (typeof(OrganizationSubscriptionUpdateJob), updateOrgSubscriptionsTrigger), }; @@ -111,7 +103,6 @@ public class JobsHostedService : BaseJobsHostedService services.AddTransient(); services.AddTransient(); services.AddTransient(); - services.AddTransient(); services.AddTransient(); } diff --git a/src/Api/Jobs/UpdatePhishingDomainsJob.cs b/src/Api/Jobs/UpdatePhishingDomainsJob.cs deleted file mode 100644 index 355f2af69b..0000000000 --- a/src/Api/Jobs/UpdatePhishingDomainsJob.cs +++ /dev/null @@ -1,97 +0,0 @@ -using Bit.Core; -using Bit.Core.Jobs; -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Settings; -using Quartz; - -namespace Bit.Api.Jobs; - -public class UpdatePhishingDomainsJob : BaseJob -{ - private readonly GlobalSettings _globalSettings; - private readonly IPhishingDomainRepository _phishingDomainRepository; - private readonly ICloudPhishingDomainQuery _cloudPhishingDomainQuery; - private readonly IFeatureService _featureService; - public UpdatePhishingDomainsJob( - GlobalSettings globalSettings, - IPhishingDomainRepository phishingDomainRepository, - ICloudPhishingDomainQuery cloudPhishingDomainQuery, - IFeatureService featureService, - ILogger logger) - : base(logger) - { - _globalSettings = globalSettings; - _phishingDomainRepository = phishingDomainRepository; - _cloudPhishingDomainQuery = cloudPhishingDomainQuery; - _featureService = featureService; - } - - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - if (!_featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Feature flag is disabled."); - return; - } - - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. No URL configured."); - return; - } - - if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Cloud communication is disabled in global settings."); - return; - } - - var remoteChecksum = await _cloudPhishingDomainQuery.GetRemoteChecksumAsync(); - if (string.IsNullOrWhiteSpace(remoteChecksum)) - { - _logger.LogWarning(Constants.BypassFiltersEventId, "Could not retrieve remote checksum. Skipping update."); - return; - } - - var currentChecksum = await _phishingDomainRepository.GetCurrentChecksumAsync(); - - if (string.Equals(currentChecksum, remoteChecksum, StringComparison.OrdinalIgnoreCase)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Phishing domains list is up to date (checksum: {Checksum}). Skipping update.", - currentChecksum); - return; - } - - _logger.LogInformation(Constants.BypassFiltersEventId, - "Checksums differ (current: {CurrentChecksum}, remote: {RemoteChecksum}). Fetching updated domains from {Source}.", - currentChecksum, remoteChecksum, _globalSettings.SelfHosted ? "Bitwarden cloud API" : "external source"); - - try - { - var domains = await _cloudPhishingDomainQuery.GetPhishingDomainsAsync(); - if (!domains.Contains("phishing.testcategory.com", StringComparer.OrdinalIgnoreCase)) - { - domains.Add("phishing.testcategory.com"); - } - - if (domains.Count > 0) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Updating {Count} phishing domains with checksum {Checksum}.", - domains.Count, remoteChecksum); - await _phishingDomainRepository.UpdatePhishingDomainsAsync(domains, remoteChecksum); - _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated phishing domains."); - } - else - { - _logger.LogWarning(Constants.BypassFiltersEventId, "No valid domains found in the response. Skipping update."); - } - } - catch (Exception ex) - { - _logger.LogError(Constants.BypassFiltersEventId, ex, "Error updating phishing domains."); - } - } -} diff --git a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs index 7968970048..b944cdd052 100644 --- a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs +++ b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs @@ -1,8 +1,8 @@ -#nullable enable -using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.WebAuthn; using Bit.Api.KeyManagement.Models.Requests; +using Bit.Api.KeyManagement.Models.Responses; using Bit.Api.KeyManagement.Validators; using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; @@ -14,6 +14,7 @@ using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Commands.Interfaces; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Services; @@ -45,11 +46,13 @@ public class AccountsKeyManagementController : Controller private readonly IRotationValidator, IEnumerable> _webauthnKeyValidator; private readonly IRotationValidator, IEnumerable> _deviceValidator; + private readonly IKeyConnectorConfirmationDetailsQuery _keyConnectorConfirmationDetailsQuery; public AccountsKeyManagementController(IUserService userService, IFeatureService featureService, IOrganizationUserRepository organizationUserRepository, IEmergencyAccessRepository emergencyAccessRepository, + IKeyConnectorConfirmationDetailsQuery keyConnectorConfirmationDetailsQuery, IRegenerateUserAsymmetricKeysCommand regenerateUserAsymmetricKeysCommand, IRotateUserAccountKeysCommand rotateUserKeyCommandV2, IRotationValidator, IEnumerable> cipherValidator, @@ -75,12 +78,13 @@ public class AccountsKeyManagementController : Controller _organizationUserValidator = organizationUserValidator; _webauthnKeyValidator = webAuthnKeyValidator; _deviceValidator = deviceValidator; + _keyConnectorConfirmationDetailsQuery = keyConnectorConfirmationDetailsQuery; } [HttpPost("key-management/regenerate-keys")] public async Task RegenerateKeysAsync([FromBody] KeyRegenerationRequestModel request) { - if (!_featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration)) + if (!_featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration) && !_featureService.IsEnabled(FeatureFlagKeys.DataRecoveryTool)) { throw new NotFoundException(); } @@ -178,4 +182,17 @@ public class AccountsKeyManagementController : Controller throw new BadRequestException(ModelState); } + + [HttpGet("key-connector/confirmation-details/{orgSsoIdentifier}")] + public async Task GetKeyConnectorConfirmationDetailsAsync(string orgSsoIdentifier) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var details = await _keyConnectorConfirmationDetailsQuery.Run(orgSsoIdentifier, user.Id); + return new KeyConnectorConfirmationDetailsResponseModel(details); + } } diff --git a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs index 02780b015a..3510be9546 100644 --- a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs @@ -1,4 +1,5 @@ using System.ComponentModel.DataAnnotations; +using Bit.Core.KeyManagement.Models.Api.Request; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs b/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs new file mode 100644 index 0000000000..68d2c689df --- /dev/null +++ b/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs @@ -0,0 +1,24 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Api.KeyManagement.Models.Responses; + +public class KeyConnectorConfirmationDetailsResponseModel : ResponseModel +{ + private const string _objectName = "keyConnectorConfirmationDetails"; + + public KeyConnectorConfirmationDetailsResponseModel(KeyConnectorConfirmationDetails details, + string obj = _objectName) : base(obj) + { + ArgumentNullException.ThrowIfNull(details); + + OrganizationName = details.OrganizationName; + } + + public KeyConnectorConfirmationDetailsResponseModel() : base(_objectName) + { + OrganizationName = string.Empty; + } + + public string OrganizationName { get; set; } +} diff --git a/src/Api/Models/Response/KeysResponseModel.cs b/src/Api/Models/Response/KeysResponseModel.cs index cfc1a6a0a1..4c877e0bfc 100644 --- a/src/Api/Models/Response/KeysResponseModel.cs +++ b/src/Api/Models/Response/KeysResponseModel.cs @@ -1,27 +1,32 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; namespace Bit.Api.Models.Response; public class KeysResponseModel : ResponseModel { - public KeysResponseModel(User user) + public KeysResponseModel(UserAccountKeysData accountKeys, string? masterKeyWrappedUserKey) : base("keys") { - if (user == null) + if (masterKeyWrappedUserKey != null) { - throw new ArgumentNullException(nameof(user)); + Key = masterKeyWrappedUserKey; } - Key = user.Key; - PublicKey = user.PublicKey; - PrivateKey = user.PrivateKey; + PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + PrivateKey = accountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + AccountKeys = new PrivateKeysResponseModel(accountKeys); } - public string Key { get; set; } + /// + /// The master key wrapped user key. The master key can either be a master-password master key or a + /// key-connector master key. + /// + public string? Key { get; set; } + [Obsolete("Use AccountKeys.PublicKeyEncryptionKeyPair.PublicKey instead")] public string PublicKey { get; set; } + [Obsolete("Use AccountKeys.PublicKeyEncryptionKeyPair.WrappedPrivateKey instead")] public string PrivateKey { get; set; } + public PrivateKeysResponseModel AccountKeys { get; set; } } diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index 29a47e160c..32d12aa416 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -1,4 +1,7 @@ -using Bit.Core.Billing.Constants; +using System.Security.Claims; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Models.Business; using Bit.Core.Entities; using Bit.Core.Models.Api; @@ -37,6 +40,46 @@ public class SubscriptionResponseModel : ResponseModel : null; } + /// The user entity containing storage and premium subscription information + /// Subscription information retrieved from the payment provider (Stripe/Braintree) + /// The user's license containing expiration and feature entitlements + /// The claims principal containing cryptographically secure token claims + /// + /// Whether to include discount information in the response. + /// Set to true when the PM23341_Milestone_2 feature flag is enabled AND + /// you want to expose Milestone 2 discount information to the client. + /// The discount will only be included if it matches the specific Milestone 2 coupon ID. + /// + public SubscriptionResponseModel(User user, SubscriptionInfo? subscription, UserLicense license, ClaimsPrincipal? claimsPrincipal, bool includeMilestone2Discount = false) + : base("subscription") + { + Subscription = subscription?.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription?.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; + License = license; + + // CRITICAL: When a license has a Token (JWT), ALWAYS use the expiration from the token claim + // The token's expiration is cryptographically secured and cannot be tampered with + // The file's Expires property can be manually edited and should NOT be trusted for display + if (claimsPrincipal != null) + { + Expiration = claimsPrincipal.GetValue(UserLicenseConstants.Expires); + } + else + { + // No token - use the license file expiration (for older licenses without tokens) + Expiration = License.Expires; + } + + // Only display the Milestone 2 subscription discount on the subscription page. + CustomerDiscount = ShouldIncludeMilestone2Discount(includeMilestone2Discount, subscription?.CustomerDiscount) + ? new BillingCustomerDiscount(subscription!.CustomerDiscount!) + : null; + } + public SubscriptionResponseModel(User user, UserLicense? license = null) : base("subscription") { diff --git a/src/Api/Program.cs b/src/Api/Program.cs index 6023f51c6d..bf924af47f 100644 --- a/src/Api/Program.cs +++ b/src/Api/Program.cs @@ -1,9 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using AspNetCoreRateLimit; -using Bit.Core.Utilities; -using Microsoft.IdentityModel.Tokens; +using Bit.Core.Utilities; namespace Bit.Api; @@ -17,32 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Exception != null && - (e.Exception.GetType() == typeof(SecurityTokenValidationException) || - e.Exception.Message == "Bad security stamp.")) - { - return false; - } - - if ( - context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return e.Level >= globalSettings.MinLogLevel.ApiSettings.IpRateLimit; - } - - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.ApiSettings.IdentityToken; - } - - return e.Level >= globalSettings.MinLogLevel.ApiSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Api/SecretsManager/Controllers/SecretVersionsController.cs b/src/Api/SecretsManager/Controllers/SecretVersionsController.cs new file mode 100644 index 0000000000..86e2d1f7e9 --- /dev/null +++ b/src/Api/SecretsManager/Controllers/SecretVersionsController.cs @@ -0,0 +1,337 @@ +using Bit.Api.Models.Response; +using Bit.Api.SecretsManager.Models.Request; +using Bit.Api.SecretsManager.Models.Response; +using Bit.Core.Auth.Identity; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.SecretsManager.Controllers; + +[Authorize("secrets")] +public class SecretVersionsController : Controller +{ + private readonly ICurrentContext _currentContext; + private readonly ISecretVersionRepository _secretVersionRepository; + private readonly ISecretRepository _secretRepository; + private readonly IUserService _userService; + private readonly IOrganizationUserRepository _organizationUserRepository; + + public SecretVersionsController( + ICurrentContext currentContext, + ISecretVersionRepository secretVersionRepository, + ISecretRepository secretRepository, + IUserService userService, + IOrganizationUserRepository organizationUserRepository) + { + _currentContext = currentContext; + _secretVersionRepository = secretVersionRepository; + _secretRepository = secretRepository; + _userService = userService; + _organizationUserRepository = organizationUserRepository; + } + + [HttpGet("secrets/{secretId}/versions")] + public async Task> GetVersionsBySecretIdAsync([FromRoute] Guid secretId) + { + var secret = await _secretRepository.GetByIdAsync(secretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access above + var versionList = await _secretVersionRepository.GetManyBySecretIdAsync(secretId); + var responseList = versionList.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(responseList); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretId, userId.Value, accessClient); + if (!access.Read) + { + throw new NotFoundException(); + } + + var versions = await _secretVersionRepository.GetManyBySecretIdAsync(secretId); + var responses = versions.Select(v => new SecretVersionResponseModel(v)); + + return new ListResponseModel(responses); + } + + [HttpGet("secret-versions/{id}")] + public async Task GetByIdAsync([FromRoute] Guid id) + { + var secretVersion = await _secretVersionRepository.GetByIdAsync(id); + if (secretVersion == null) + { + throw new NotFoundException(); + } + + var secret = await _secretRepository.GetByIdAsync(secretVersion.SecretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access above + return new SecretVersionResponseModel(secretVersion); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretVersion.SecretId, userId.Value, accessClient); + if (!access.Read) + { + throw new NotFoundException(); + } + + return new SecretVersionResponseModel(secretVersion); + } + + [HttpPost("secret-versions/get-by-ids")] + public async Task> GetManyByIdsAsync([FromBody] List ids) + { + if (!ids.Any()) + { + throw new BadRequestException("No version IDs provided."); + } + + // Get all versions + var versions = (await _secretVersionRepository.GetManyByIdsAsync(ids)).ToList(); + if (!versions.Any()) + { + throw new NotFoundException(); + } + + // Get all associated secrets and check permissions + var secretIds = versions.Select(v => v.SecretId).Distinct().ToList(); + var secrets = (await _secretRepository.GetManyByIds(secretIds)).ToList(); + + if (!secrets.Any()) + { + throw new NotFoundException(); + } + + // Ensure all secrets belong to the same organization + var organizationId = secrets.First().OrganizationId; + if (secrets.Any(s => s.OrganizationId != organizationId) || + !_currentContext.AccessSecretsManager(organizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access and organization ownership above + var serviceAccountResponses = versions.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(serviceAccountResponses); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var isAdmin = await _currentContext.OrganizationAdmin(organizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, isAdmin); + + // Verify read access to all associated secrets + var accessResults = await _secretRepository.AccessToSecretsAsync(secretIds, userId.Value, accessClient); + if (accessResults.Values.Any(access => !access.Read)) + { + throw new NotFoundException(); + } + + var responses = versions.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(responses); + } + + [HttpPut("secrets/{secretId}/versions/restore")] + public async Task RestoreVersionAsync([FromRoute] Guid secretId, [FromBody] RestoreSecretVersionRequestModel request) + { + if (!(_currentContext.IdentityClientType == IdentityClientType.User || _currentContext.IdentityClientType == IdentityClientType.ServiceAccount)) + { + throw new NotFoundException(); + } + + var secret = await _secretRepository.GetByIdAsync(secretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // Get the version first to validate it belongs to this secret + var version = await _secretVersionRepository.GetByIdAsync(request.VersionId); + if (version == null || version.SecretId != secretId) + { + throw new NotFoundException(); + } + + // Store the current value before restoration + var currentValue = secret.Value; + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount) + { + // Save current value as a version before restoring + if (currentValue != version.Value) + { + var editorUserId = _userService.GetProperUserId(User); + if (editorUserId.HasValue) + { + var currentVersionSnapshot = new Core.SecretsManager.Entities.SecretVersion + { + SecretId = secretId, + Value = currentValue!, + VersionDate = DateTime.UtcNow, + EditorServiceAccountId = editorUserId.Value + }; + + await _secretVersionRepository.CreateAsync(currentVersionSnapshot); + } + } + + // Already verified Secrets Manager access above + secret.Value = version.Value; + secret.RevisionDate = DateTime.UtcNow; + var updatedSec = await _secretRepository.UpdateAsync(secret); + return new SecretResponseModel(updatedSec, true, true); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretId, userId.Value, accessClient); + if (!access.Write) + { + throw new NotFoundException(); + } + + // Save current value as a version before restoring + if (currentValue != version.Value) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(secret.OrganizationId, userId.Value); + if (orgUser == null) + { + throw new NotFoundException(); + } + + var currentVersionSnapshot = new Core.SecretsManager.Entities.SecretVersion + { + SecretId = secretId, + Value = currentValue!, + VersionDate = DateTime.UtcNow, + EditorOrganizationUserId = orgUser.Id + }; + + await _secretVersionRepository.CreateAsync(currentVersionSnapshot); + } + + // Update the secret with the version's value + secret.Value = version.Value; + secret.RevisionDate = DateTime.UtcNow; + + var updatedSecret = await _secretRepository.UpdateAsync(secret); + + return new SecretResponseModel(updatedSecret, true, true); + } + + [HttpPost("secret-versions/delete")] + public async Task BulkDeleteAsync([FromBody] List ids) + { + if (!ids.Any()) + { + throw new BadRequestException("No version IDs provided."); + } + + var secretVersions = (await _secretVersionRepository.GetManyByIdsAsync(ids)).ToList(); + if (secretVersions.Count != ids.Count) + { + throw new NotFoundException(); + } + + // Ensure all versions belong to secrets in the same organization + var secretIds = secretVersions.Select(v => v.SecretId).Distinct().ToList(); + var secrets = await _secretRepository.GetManyByIds(secretIds); + var secretsList = secrets.ToList(); + + if (!secretsList.Any()) + { + throw new NotFoundException(); + } + + var organizationId = secretsList.First().OrganizationId; + if (secretsList.Any(s => s.OrganizationId != organizationId) || + !_currentContext.AccessSecretsManager(organizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access and organization ownership above + await _secretVersionRepository.DeleteManyByIdAsync(ids); + return Ok(); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + // Verify write access to all associated secrets + var accessResults = await _secretRepository.AccessToSecretsAsync(secretIds, userId.Value, accessClient); + if (accessResults.Values.Any(access => !access.Write)) + { + throw new NotFoundException(); + } + + await _secretVersionRepository.DeleteManyByIdAsync(ids); + + return Ok(); + } +} diff --git a/src/Api/SecretsManager/Controllers/SecretsController.cs b/src/Api/SecretsManager/Controllers/SecretsController.cs index e263b9747d..dcfe1be111 100644 --- a/src/Api/SecretsManager/Controllers/SecretsController.cs +++ b/src/Api/SecretsManager/Controllers/SecretsController.cs @@ -8,6 +8,7 @@ using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Repositories; using Bit.Core.SecretsManager.AuthorizationRequirements; using Bit.Core.SecretsManager.Commands.Secrets.Interfaces; using Bit.Core.SecretsManager.Entities; @@ -29,6 +30,7 @@ public class SecretsController : Controller private readonly ICurrentContext _currentContext; private readonly IProjectRepository _projectRepository; private readonly ISecretRepository _secretRepository; + private readonly ISecretVersionRepository _secretVersionRepository; private readonly ICreateSecretCommand _createSecretCommand; private readonly IUpdateSecretCommand _updateSecretCommand; private readonly IDeleteSecretCommand _deleteSecretCommand; @@ -38,11 +40,13 @@ public class SecretsController : Controller private readonly IUserService _userService; private readonly IEventService _eventService; private readonly IAuthorizationService _authorizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; public SecretsController( ICurrentContext currentContext, IProjectRepository projectRepository, ISecretRepository secretRepository, + ISecretVersionRepository secretVersionRepository, ICreateSecretCommand createSecretCommand, IUpdateSecretCommand updateSecretCommand, IDeleteSecretCommand deleteSecretCommand, @@ -51,11 +55,13 @@ public class SecretsController : Controller ISecretAccessPoliciesUpdatesQuery secretAccessPoliciesUpdatesQuery, IUserService userService, IEventService eventService, - IAuthorizationService authorizationService) + IAuthorizationService authorizationService, + IOrganizationUserRepository organizationUserRepository) { _currentContext = currentContext; _projectRepository = projectRepository; _secretRepository = secretRepository; + _secretVersionRepository = secretVersionRepository; _createSecretCommand = createSecretCommand; _updateSecretCommand = updateSecretCommand; _deleteSecretCommand = deleteSecretCommand; @@ -65,6 +71,7 @@ public class SecretsController : Controller _userService = userService; _eventService = eventService; _authorizationService = authorizationService; + _organizationUserRepository = organizationUserRepository; } @@ -190,6 +197,44 @@ public class SecretsController : Controller } } + // Create a version record if the value changed + if (updateRequest.ValueChanged) + { + // Store the old value before updating + var oldValue = secret.Value; + var userId = _userService.GetProperUserId(User)!.Value; + Guid? editorServiceAccountId = null; + Guid? editorOrganizationUserId = null; + + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount) + { + editorServiceAccountId = userId; + } + else if (_currentContext.IdentityClientType == IdentityClientType.User) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(secret.OrganizationId, userId); + if (orgUser != null) + { + editorOrganizationUserId = orgUser.Id; + } + else + { + throw new NotFoundException(); + } + } + + var secretVersion = new SecretVersion + { + SecretId = id, + Value = oldValue, + VersionDate = DateTime.UtcNow, + EditorServiceAccountId = editorServiceAccountId, + EditorOrganizationUserId = editorOrganizationUserId + }; + + await _secretVersionRepository.CreateAsync(secretVersion); + } + var result = await _updateSecretCommand.UpdateAsync(updatedSecret, accessPoliciesUpdates); await LogSecretEventAsync(secret, EventType.Secret_Edited); diff --git a/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs b/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs new file mode 100644 index 0000000000..19a6b35a75 --- /dev/null +++ b/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs @@ -0,0 +1,9 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.SecretsManager.Models.Request; + +public class RestoreSecretVersionRequestModel +{ + [Required] + public Guid VersionId { get; set; } +} diff --git a/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs b/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs index b95bc9e500..9d19e1d8cc 100644 --- a/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs +++ b/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs @@ -28,6 +28,8 @@ public class SecretUpdateRequestModel : IValidatableObject public SecretAccessPoliciesRequestsModel AccessPoliciesRequests { get; set; } + public bool ValueChanged { get; set; } = false; + public Secret ToSecret(Secret secret) { secret.Key = Key; diff --git a/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs b/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs new file mode 100644 index 0000000000..07b8e88f7e --- /dev/null +++ b/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs @@ -0,0 +1,28 @@ +using Bit.Core.Models.Api; +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Api.SecretsManager.Models.Response; + +public class SecretVersionResponseModel : ResponseModel +{ + private const string _objectName = "secretVersion"; + + public Guid Id { get; set; } + public Guid SecretId { get; set; } + public string Value { get; set; } = string.Empty; + public DateTime VersionDate { get; set; } + public Guid? EditorServiceAccountId { get; set; } + public Guid? EditorOrganizationUserId { get; set; } + + public SecretVersionResponseModel() : base(_objectName) { } + + public SecretVersionResponseModel(SecretVersion secretVersion) : base(_objectName) + { + Id = secretVersion.Id; + SecretId = secretVersion.SecretId; + Value = secretVersion.Value; + VersionDate = secretVersion.VersionDate; + EditorServiceAccountId = secretVersion.EditorServiceAccountId; + EditorOrganizationUserId = secretVersion.EditorOrganizationUserId; + } +} diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 0967b4f662..2f16470cd4 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -187,7 +187,6 @@ public class Startup services.AddBillingOperations(); services.AddReportingServices(); services.AddImportServices(); - services.AddPhishingDomainServices(globalSettings); services.AddSendServices(); @@ -216,7 +215,7 @@ public class Startup config.Conventions.Add(new PublicApiControllersModelConvention()); }); - services.AddSwagger(globalSettings, Environment); + services.AddSwaggerGen(globalSettings, Environment); Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); services.AddHostedService(); @@ -226,7 +225,8 @@ public class Startup services.AddHostedService(); } - // Add Slack / Teams Services for OAuth API requests - if configured + // Add Event Integrations services + services.AddEventIntegrationsCommandsQueries(globalSettings); services.AddSlackService(globalSettings); services.AddTeamsService(globalSettings); } @@ -234,12 +234,10 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings, ILogger logger) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); @@ -294,17 +292,59 @@ public class Startup }); // Add Swagger + // Note that the swagger.json generation is configured in the call to AddSwaggerGen above. if (Environment.IsDevelopment() || globalSettings.SelfHosted) { + // adds the middleware to serve the swagger.json while the server is running app.UseSwagger(config => { config.RouteTemplate = "specs/{documentName}/swagger.json"; + + // Remove all Bitwarden cloud servers and only register the local server config.PreSerializeFilters.Add((swaggerDoc, httpReq) => - swaggerDoc.Servers = new List + { + swaggerDoc.Servers.Clear(); + swaggerDoc.Servers.Add(new OpenApiServer { - new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } + Url = globalSettings.BaseServiceUri.Api, }); + + swaggerDoc.Components.SecuritySchemes.Clear(); + swaggerDoc.Components.SecuritySchemes.Add("oauth2-client-credentials", new OpenApiSecurityScheme + { + Type = SecuritySchemeType.OAuth2, + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), + Scopes = new Dictionary + { + { ApiScopes.ApiOrganization, "Organization APIs" } + } + } + } + }); + + swaggerDoc.SecurityRequirements.Clear(); + swaggerDoc.SecurityRequirements.Add(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = "oauth2-client-credentials" + } + }, + [ApiScopes.ApiOrganization] + } + }); + }); }); + + // adds the middleware to display the web UI app.UseSwaggerUI(config => { config.DocumentTitle = "Bitwarden API Documentation"; diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index 6af688f548..b773abf6ef 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -1,15 +1,11 @@ using Bit.Api.AdminConsole.Authorization; using Bit.Api.Tools.Authorization; -using Bit.Core.Auth.IdentityServer; -using Bit.Core.PhishingDomainFeatures; -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Repositories.Implementations; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Core.Vault.Authorization.SecurityTasks; using Bit.SharedWeb.Health; using Bit.SharedWeb.Swagger; +using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.OpenApi.Models; @@ -17,7 +13,10 @@ namespace Bit.Api.Utilities; public static class ServiceCollectionExtensions { - public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment) + /// + /// Configures the generation of swagger.json OpenAPI spec. + /// + public static void AddSwaggerGen(this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment) { services.AddSwaggerGen(config => { @@ -36,6 +35,8 @@ public static class ServiceCollectionExtensions organizations tools for managing members, collections, groups, event logs, and policies. If you are looking for the Vault Management API, refer instead to [this document](https://bitwarden.com/help/vault-management-api/). + + **Note:** your authorization must match the server you have selected. """, License = new OpenApiLicense { @@ -46,36 +47,20 @@ public static class ServiceCollectionExtensions config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); - config.AddSecurityDefinition("oauth2-client-credentials", new OpenApiSecurityScheme - { - Type = SecuritySchemeType.OAuth2, - Flows = new OpenApiOAuthFlows - { - ClientCredentials = new OpenApiOAuthFlow - { - TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), - Scopes = new Dictionary - { - { ApiScopes.ApiOrganization, "Organization APIs" }, - }, - } - }, - }); + // Configure Bitwarden cloud US and EU servers. These will appear in the swagger.json build artifact + // used for our help center. These are overwritten with the local server when running in self-hosted + // or dev mode (see Api Startup.cs). + config.AddSwaggerServerWithSecurity( + serverId: "US_server", + serverUrl: "https://api.bitwarden.com", + identityTokenUrl: "https://identity.bitwarden.com/connect/token", + serverDescription: "US server"); - config.AddSecurityRequirement(new OpenApiSecurityRequirement - { - { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = "oauth2-client-credentials" - }, - }, - new[] { ApiScopes.ApiOrganization } - } - }); + config.AddSwaggerServerWithSecurity( + serverId: "EU_server", + serverUrl: "https://api.bitwarden.eu", + identityTokenUrl: "https://identity.bitwarden.eu/connect/token", + serverDescription: "EU server"); config.DescribeAllParametersInCamelCase(); // config.UseReferencedDefinitionsForEnums(); @@ -114,25 +99,4 @@ public static class ServiceCollectionExtensions // Admin Console authorization handlers services.AddAdminConsoleAuthorizationHandlers(); } - - public static void AddPhishingDomainServices(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddHttpClient("PhishingDomains", client => - { - client.DefaultRequestHeaders.Add("User-Agent", globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"); - client.Timeout = TimeSpan.FromSeconds(1000); // the source list is very slow - }); - - services.AddSingleton(); - services.AddSingleton(); - - if (globalSettings.SelfHosted) - { - services.AddScoped(); - } - else - { - services.AddScoped(); - } - } } diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index 0983225f84..6a506cc01f 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -757,15 +757,10 @@ public class CiphersController : Controller } } - if (cipher.ArchivedDate.HasValue) - { - throw new BadRequestException("Cannot move an archived item to an organization."); - } - ValidateClientVersionForFido2CredentialSupport(cipher); var original = cipher.Clone(); - await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), + await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher, user.Id), new Guid(model.Cipher.OrganizationId), model.CollectionIds.Select(c => new Guid(c)), user.Id, model.Cipher.LastKnownRevisionDate); var sharedCipher = await GetByIdAsync(id, user.Id); @@ -1271,11 +1266,6 @@ public class CiphersController : Controller _logger.LogError("Cipher was not encrypted for the current user. CipherId: {CipherId}, CurrentUser: {CurrentUserId}, EncryptedFor: {EncryptedFor}", cipher.Id, userId, cipher.EncryptedFor); throw new BadRequestException("Cipher was not encrypted for the current user. Please try again."); } - - if (cipher.ArchivedDate.HasValue) - { - throw new BadRequestException("Cannot move archived items to an organization."); - } } var shareCiphers = new List<(CipherDetails, DateTime?)>(); @@ -1288,11 +1278,6 @@ public class CiphersController : Controller ValidateClientVersionForFido2CredentialSupport(existingCipher); - if (existingCipher.ArchivedDate.HasValue) - { - throw new BadRequestException("Cannot move archived items to an organization."); - } - shareCiphers.Add((cipher.ToCipherDetails(existingCipher), cipher.LastKnownRevisionDate)); } @@ -1422,11 +1407,9 @@ public class CiphersController : Controller throw new NotFoundException(); } - // Extract lastKnownRevisionDate from form data if present - DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream) => { - await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData, lastKnownRevisionDate); + await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); }); } @@ -1525,13 +1508,10 @@ public class CiphersController : Controller throw new NotFoundException(); } - // Extract lastKnownRevisionDate from form data if present - DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); - await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentShareAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId, lastKnownRevisionDate); + Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); }); } diff --git a/src/Api/Vault/Models/Request/CipherRequestModel.cs b/src/Api/Vault/Models/Request/CipherRequestModel.cs index b0589a62f9..18a1aec559 100644 --- a/src/Api/Vault/Models/Request/CipherRequestModel.cs +++ b/src/Api/Vault/Models/Request/CipherRequestModel.cs @@ -84,7 +84,7 @@ public class CipherRequestModel return existingCipher; } - public Cipher ToCipher(Cipher existingCipher) + public Cipher ToCipher(Cipher existingCipher, Guid? userId = null) { // If Data field is provided, use it directly if (!string.IsNullOrWhiteSpace(Data)) @@ -124,9 +124,12 @@ public class CipherRequestModel } } + var userIdKey = userId.HasValue ? userId.ToString().ToUpperInvariant() : null; existingCipher.Reprompt = Reprompt; existingCipher.Key = Key; existingCipher.ArchivedDate = ArchivedDate; + existingCipher.Folders = UpdateUserSpecificJsonField(existingCipher.Folders, userIdKey, FolderId); + existingCipher.Favorites = UpdateUserSpecificJsonField(existingCipher.Favorites, userIdKey, Favorite); var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; var hasAttachments = (Attachments?.Count ?? 0) > 0; @@ -291,6 +294,37 @@ public class CipherRequestModel KeyFingerprint = SSHKey.KeyFingerprint, }; } + + /// + /// Updates a JSON string representing a dictionary by adding, updating, or removing a key-value pair + /// based on the provided userIdKey and newValue. + /// + private static string UpdateUserSpecificJsonField(string existingJson, string userIdKey, object newValue) + { + if (userIdKey == null) + { + return existingJson; + } + + var jsonDict = string.IsNullOrWhiteSpace(existingJson) + ? new Dictionary() + : JsonSerializer.Deserialize>(existingJson) ?? new Dictionary(); + + var shouldRemove = newValue == null || + (newValue is string strValue && string.IsNullOrWhiteSpace(strValue)) || + (newValue is bool boolValue && !boolValue); + + if (shouldRemove) + { + jsonDict.Remove(userIdKey); + } + else + { + jsonDict[userIdKey] = newValue is string str ? str.ToUpperInvariant() : newValue; + } + + return jsonDict.Count == 0 ? null : JsonSerializer.Serialize(jsonDict); + } } public class CipherWithIdRequestModel : CipherRequestModel diff --git a/src/Api/appsettings.Development.json b/src/Api/appsettings.Development.json index 82fb951261..deb0a35d84 100644 --- a/src/Api/appsettings.Development.json +++ b/src/Api/appsettings.Development.json @@ -38,9 +38,6 @@ "storage": { "connectionString": "UseDevelopmentStorage=true" }, - "phishingDomain": { - "updateUrl": "https://phish.co.za/latest/phishing-domains-ACTIVE.txt", - "checksumUrl": "https://raw.githubusercontent.com/Phishing-Database/checksums/refs/heads/master/phishing-domains-ACTIVE.txt.sha256" - } + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index 98bb4df8ac..8850c3d269 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -32,9 +32,6 @@ "send": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" @@ -72,9 +69,6 @@ "accessKeySecret": "SECRET", "region": "SECRET" }, - "phishingDomain": { - "updateUrl": "SECRET" - }, "distributedIpRateLimiting": { "enabled": true, "maxRedisTimeoutsThreshold": 10, diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj index e2b7447eb7..69999dc795 100644 --- a/src/Billing/Billing.csproj +++ b/src/Billing/Billing.csproj @@ -1,9 +1,17 @@  + bitwarden-Billing + + + false + false + false + + @@ -11,7 +19,7 @@ - + diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index b24a8d8c36..f55b4523af 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -29,7 +29,7 @@ public class BitPayController( IUserRepository userRepository, IProviderRepository providerRepository, IMailService mailService, - IPaymentService paymentService, + IStripePaymentService paymentService, ILogger logger, IPremiumUserBillingService premiumUserBillingService) : Controller diff --git a/src/Billing/Controllers/JobsController.cs b/src/Billing/Controllers/JobsController.cs new file mode 100644 index 0000000000..6a5e8e5531 --- /dev/null +++ b/src/Billing/Controllers/JobsController.cs @@ -0,0 +1,36 @@ +using Bit.Billing.Jobs; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Billing.Controllers; + +[Route("jobs")] +[SelfHosted(NotSelfHostedOnly = true)] +[RequireLowerEnvironment] +public class JobsController( + JobsHostedService jobsHostedService) : Controller +{ + [HttpPost("run/{jobName}")] + public async Task RunJobAsync(string jobName) + { + if (jobName == nameof(ReconcileAdditionalStorageJob)) + { + await jobsHostedService.RunJobAdHocAsync(); + return Ok(new { message = $"Job {jobName} scheduled successfully" }); + } + + return BadRequest(new { error = $"Unknown job name: {jobName}" }); + } + + [HttpPost("stop/{jobName}")] + public async Task StopJobAsync(string jobName) + { + if (jobName == nameof(ReconcileAdditionalStorageJob)) + { + await jobsHostedService.InterruptAdHocJobAsync(); + return Ok(new { message = $"Job {jobName} queued for cancellation" }); + } + + return BadRequest(new { error = $"Unknown job name: {jobName}" }); + } +} diff --git a/src/Billing/Controllers/PayPalController.cs b/src/Billing/Controllers/PayPalController.cs index 8039680fd5..70023b6bdb 100644 --- a/src/Billing/Controllers/PayPalController.cs +++ b/src/Billing/Controllers/PayPalController.cs @@ -23,7 +23,7 @@ public class PayPalController : Controller private readonly ILogger _logger; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ITransactionRepository _transactionRepository; private readonly IUserRepository _userRepository; private readonly IProviderRepository _providerRepository; @@ -34,7 +34,7 @@ public class PayPalController : Controller ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ITransactionRepository transactionRepository, IUserRepository userRepository, IProviderRepository providerRepository, diff --git a/src/Billing/Jobs/AliveJob.cs b/src/Billing/Jobs/AliveJob.cs index 42f64099ac..1769cc94e2 100644 --- a/src/Billing/Jobs/AliveJob.cs +++ b/src/Billing/Jobs/AliveJob.cs @@ -10,4 +10,13 @@ public class AliveJob(ILogger logger) : BaseJob(logger) _logger.LogInformation(Core.Constants.BypassFiltersEventId, null, "Billing service is alive!"); return Task.FromResult(0); } + + public static ITrigger GetTrigger() + { + return TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + } } diff --git a/src/Billing/Jobs/JobsHostedService.cs b/src/Billing/Jobs/JobsHostedService.cs index a6e702c662..25c57044da 100644 --- a/src/Billing/Jobs/JobsHostedService.cs +++ b/src/Billing/Jobs/JobsHostedService.cs @@ -1,29 +1,27 @@ -using Bit.Core.Jobs; +using Bit.Core.Exceptions; +using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; namespace Bit.Billing.Jobs; -public class JobsHostedService : BaseJobsHostedService +public class JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger, + ISchedulerFactory schedulerFactory) + : BaseJobsHostedService(globalSettings, serviceProvider, logger, listenerLogger) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } + private List AdHocJobKeys { get; } = []; + private IScheduler? _adHocScheduler; public override async Task StartAsync(CancellationToken cancellationToken) { - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - Jobs = new List> { - new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger) + new(typeof(AliveJob), AliveJob.GetTrigger()), + new(typeof(ReconcileAdditionalStorageJob), ReconcileAdditionalStorageJob.GetTrigger()) }; await base.StartAsync(cancellationToken); @@ -33,5 +31,54 @@ public class JobsHostedService : BaseJobsHostedService { services.AddTransient(); services.AddTransient(); + services.AddTransient(); + // add this service as a singleton so we can inject it where needed + services.AddSingleton(); + services.AddHostedService(sp => sp.GetRequiredService()); + } + + public async Task InterruptAdHocJobAsync(CancellationToken cancellationToken = default) where T : class, IJob + { + if (_adHocScheduler == null) + { + throw new InvalidOperationException("AdHocScheduler is null, cannot interrupt ad-hoc job."); + } + + var jobKey = AdHocJobKeys.FirstOrDefault(j => j.Name == typeof(T).ToString()); + if (jobKey == null) + { + throw new NotFoundException($"Cannot find job key: {typeof(T)}, not running?"); + } + logger.LogInformation("CANCELLING ad-hoc job with key: {JobKey}", jobKey); + AdHocJobKeys.Remove(jobKey); + await _adHocScheduler.Interrupt(jobKey, cancellationToken); + } + + public async Task RunJobAdHocAsync(CancellationToken cancellationToken = default) where T : class, IJob + { + _adHocScheduler ??= await schedulerFactory.GetScheduler(cancellationToken); + + var jobKey = new JobKey(typeof(T).ToString()); + + var currentlyExecuting = await _adHocScheduler.GetCurrentlyExecutingJobs(cancellationToken); + if (currentlyExecuting.Any(j => j.JobDetail.Key.Equals(jobKey))) + { + throw new InvalidOperationException($"Job {jobKey} is already running"); + } + + AdHocJobKeys.Add(jobKey); + + var job = JobBuilder.Create() + .WithIdentity(jobKey) + .Build(); + + var trigger = TriggerBuilder.Create() + .WithIdentity(typeof(T).ToString()) + .StartNow() + .Build(); + + logger.LogInformation("Scheduling ad-hoc job with key: {JobKey}", jobKey); + + await _adHocScheduler.ScheduleJob(job, trigger, cancellationToken); } } diff --git a/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs b/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs new file mode 100644 index 0000000000..312ed3122b --- /dev/null +++ b/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs @@ -0,0 +1,193 @@ +using System.Globalization; +using System.Text.Json; +using Bit.Billing.Services; +using Bit.Core; +using Bit.Core.Billing.Constants; +using Bit.Core.Jobs; +using Bit.Core.Services; +using Quartz; +using Stripe; + +namespace Bit.Billing.Jobs; + +public class ReconcileAdditionalStorageJob( + IStripeFacade stripeFacade, + ILogger logger, + IFeatureService featureService) : BaseJob(logger) +{ + private const string _storageGbMonthlyPriceId = "storage-gb-monthly"; + private const string _storageGbAnnuallyPriceId = "storage-gb-annually"; + private const string _personalStorageGbAnnuallyPriceId = "personal-storage-gb-annually"; + private const int _storageGbToRemove = 4; + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob)) + { + logger.LogInformation("Skipping ReconcileAdditionalStorageJob, feature flag off."); + return; + } + + var liveMode = featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode); + + // Execution tracking + var subscriptionsFound = 0; + var subscriptionsUpdated = 0; + var subscriptionsWithErrors = 0; + var failures = new List(); + + logger.LogInformation("Starting ReconcileAdditionalStorageJob (live mode: {LiveMode})", liveMode); + + var priceIds = new[] { _storageGbMonthlyPriceId, _storageGbAnnuallyPriceId, _personalStorageGbAnnuallyPriceId }; + var stripeStatusesToProcess = new[] { StripeConstants.SubscriptionStatus.Active, StripeConstants.SubscriptionStatus.Trialing, StripeConstants.SubscriptionStatus.PastDue }; + + foreach (var priceId in priceIds) + { + var options = new SubscriptionListOptions { Limit = 100, Price = priceId }; + + await foreach (var subscription in stripeFacade.ListSubscriptionsAutoPagingAsync(options)) + { + if (context.CancellationToken.IsCancellationRequested) + { + logger.LogWarning( + "Job cancelled!! Exiting. Progress at time of cancellation: Subscriptions found: {SubscriptionsFound}, " + + "Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}", + subscriptionsFound, + liveMode + ? subscriptionsUpdated + : $"(In live mode, would have updated) {subscriptionsUpdated}", + subscriptionsWithErrors, + failures.Count > 0 + ? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}" + : string.Empty + ); + return; + } + + if (subscription == null) + { + continue; + } + + if (!stripeStatusesToProcess.Contains(subscription.Status)) + { + logger.LogInformation("Skipping subscription with unsupported status: {SubscriptionId} - {Status}", subscription.Id, subscription.Status); + continue; + } + + logger.LogInformation("Processing subscription: {SubscriptionId}", subscription.Id); + subscriptionsFound++; + + if (subscription.Metadata?.TryGetValue(StripeConstants.MetadataKeys.StorageReconciled2025, out var dateString) == true) + { + if (DateTime.TryParse(dateString, null, DateTimeStyles.RoundtripKind, out var dateProcessed)) + { + logger.LogInformation("Skipping subscription {SubscriptionId} - already processed on {Date}", + subscription.Id, + dateProcessed.ToString("f")); + continue; + } + } + + var updateOptions = BuildSubscriptionUpdateOptions(subscription, priceId); + + if (updateOptions == null) + { + logger.LogInformation("Skipping subscription {SubscriptionId} - no updates needed", subscription.Id); + continue; + } + + subscriptionsUpdated++; + + if (!liveMode) + { + logger.LogInformation( + "Not live mode (dry-run): Would have updated subscription {SubscriptionId} with item changes: {NewLine}{UpdateOptions}", + subscription.Id, + Environment.NewLine, + JsonSerializer.Serialize(updateOptions)); + continue; + } + + try + { + await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); + logger.LogInformation("Successfully updated subscription: {SubscriptionId}", subscription.Id); + } + catch (Exception ex) + { + subscriptionsWithErrors++; + failures.Add($"Subscription {subscription.Id}: {ex.Message}"); + logger.LogError(ex, "Failed to update subscription {SubscriptionId}: {ErrorMessage}", + subscription.Id, ex.Message); + } + } + } + + logger.LogInformation( + "ReconcileAdditionalStorageJob completed. Subscriptions found: {SubscriptionsFound}, " + + "Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}", + subscriptionsFound, + liveMode + ? subscriptionsUpdated + : $"(In live mode, would have updated) {subscriptionsUpdated}", + subscriptionsWithErrors, + failures.Count > 0 + ? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}" + : string.Empty + ); + } + + private SubscriptionUpdateOptions? BuildSubscriptionUpdateOptions( + Subscription subscription, + string targetPriceId) + { + if (subscription.Items?.Data == null) + { + return null; + } + + var updateOptions = new SubscriptionUpdateOptions { ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, Metadata = new Dictionary { [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") }, Items = [] }; + + var hasUpdates = false; + + foreach (var item in subscription.Items.Data.Where(item => item?.Price?.Id == targetPriceId)) + { + hasUpdates = true; + var currentQuantity = item.Quantity; + + if (currentQuantity > _storageGbToRemove) + { + var newQuantity = currentQuantity - _storageGbToRemove; + logger.LogInformation( + "Subscription {SubscriptionId}: reducing quantity from {CurrentQuantity} to {NewQuantity} for price {PriceId}", + subscription.Id, + currentQuantity, + newQuantity, + item.Price.Id); + + updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Quantity = newQuantity }); + } + else + { + logger.LogInformation("Subscription {SubscriptionId}: deleting storage item with quantity {CurrentQuantity} for price {PriceId}", + subscription.Id, + currentQuantity, + item.Price.Id); + + updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Deleted = true }); + } + } + + return hasUpdates ? updateOptions : null; + } + + public static ITrigger GetTrigger() + { + return TriggerBuilder.Create() + .WithIdentity("EveryMorningTrigger") + .StartNow() + .WithCronSchedule("0 0 16 * * ?") // 10am CST daily; the pods execute in UTC time + .Build(); + } +} diff --git a/src/Billing/Jobs/SubscriptionCancellationJob.cs b/src/Billing/Jobs/SubscriptionCancellationJob.cs index 69b7bc876d..60b671df3d 100644 --- a/src/Billing/Jobs/SubscriptionCancellationJob.cs +++ b/src/Billing/Jobs/SubscriptionCancellationJob.cs @@ -1,16 +1,17 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Billing.Services; +using Bit.Billing.Services; +using Bit.Core.Billing.Constants; using Bit.Core.Repositories; using Quartz; using Stripe; namespace Bit.Billing.Jobs; +using static StripeConstants; + public class SubscriptionCancellationJob( IStripeFacade stripeFacade, - IOrganizationRepository organizationRepository) + IOrganizationRepository organizationRepository, + ILogger logger) : IJob { public async Task Execute(IJobExecutionContext context) @@ -21,20 +22,31 @@ public class SubscriptionCancellationJob( var organization = await organizationRepository.GetByIdAsync(organizationId); if (organization == null || organization.Enabled) { + logger.LogWarning("{Job} skipped for subscription ({SubscriptionID}) because organization is either null or enabled", nameof(SubscriptionCancellationJob), subscriptionId); // Organization was deleted or re-enabled by CS, skip cancellation return; } - var subscription = await stripeFacade.GetSubscription(subscriptionId); - if (subscription?.Status != "unpaid" || - subscription.LatestInvoice?.BillingReason is not ("subscription_cycle" or "subscription_create")) + var subscription = await stripeFacade.GetSubscription(subscriptionId, new SubscriptionGetOptions { + Expand = ["latest_invoice"] + }); + + if (subscription is not + { + Status: SubscriptionStatus.Unpaid, + LatestInvoice: { BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle } + }) + { + logger.LogWarning("{Job} skipped for subscription ({SubscriptionID}) because subscription is not unpaid or does not have a cancellable billing reason", nameof(SubscriptionCancellationJob), subscriptionId); return; } // Cancel the subscription await stripeFacade.CancelSubscription(subscriptionId, new SubscriptionCancelOptions()); + logger.LogInformation("{Job} cancelled subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), subscriptionId); + // Void any open invoices var options = new InvoiceListOptions { @@ -46,6 +58,7 @@ public class SubscriptionCancellationJob( foreach (var invoice in invoices) { await stripeFacade.VoidInvoice(invoice.Id); + logger.LogInformation("{Job} voided invoice ({InvoiceID}) for subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), invoice.Id, subscriptionId); } while (invoices.HasMore) @@ -55,6 +68,7 @@ public class SubscriptionCancellationJob( foreach (var invoice in invoices) { await stripeFacade.VoidInvoice(invoice.Id); + logger.LogInformation("{Job} voided invoice ({InvoiceID}) for subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), invoice.Id, subscriptionId); } } } diff --git a/src/Billing/Program.cs b/src/Billing/Program.cs index 3e005ce7fd..334dc49368 100644 --- a/src/Billing/Program.cs +++ b/src/Billing/Program.cs @@ -8,28 +8,12 @@ public class Program { Host .CreateDefaultBuilder(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs")) - { - return e.Level >= globalSettings.MinLogLevel.BillingSettings.Jobs; - } - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.BillingSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Billing/Services/IStripeEventUtilityService.cs b/src/Billing/Services/IStripeEventUtilityService.cs index a5f536ad11..058f56c887 100644 --- a/src/Billing/Services/IStripeEventUtilityService.cs +++ b/src/Billing/Services/IStripeEventUtilityService.cs @@ -36,7 +36,7 @@ public interface IStripeEventUtilityService /// /// /// /// - Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId); + Task FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId); /// /// Attempts to pay the specified invoice. If a customer is eligible, the invoice is paid using Braintree or Stripe. diff --git a/src/Billing/Services/IStripeFacade.cs b/src/Billing/Services/IStripeFacade.cs index 280a3aca3c..c7073b9cf9 100644 --- a/src/Billing/Services/IStripeFacade.cs +++ b/src/Billing/Services/IStripeFacade.cs @@ -20,6 +20,12 @@ public interface IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GetCustomerCashBalanceTransactions( + string customerId, + CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task UpdateCustomer( string customerId, CustomerUpdateOptions customerUpdateOptions = null, @@ -78,6 +84,11 @@ public interface IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + IAsyncEnumerable ListSubscriptionsAutoPagingAsync( + SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -111,4 +122,10 @@ public interface IStripeFacade TestClockGetOptions testClockGetOptions = null, RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + + Task GetCoupon( + string couponId, + CouponGetOptions couponGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); } diff --git a/src/Billing/Services/Implementations/ChargeRefundedHandler.cs b/src/Billing/Services/Implementations/ChargeRefundedHandler.cs index 905491b6c5..8cc3cb2ce6 100644 --- a/src/Billing/Services/Implementations/ChargeRefundedHandler.cs +++ b/src/Billing/Services/Implementations/ChargeRefundedHandler.cs @@ -38,7 +38,7 @@ public class ChargeRefundedHandler : IChargeRefundedHandler { // Attempt to create a transaction for the charge if it doesn't exist var (organizationId, userId, providerId) = await _stripeEventUtilityService.GetEntityIdsFromChargeAsync(charge); - var tx = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId); + var tx = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId); try { parentTransaction = await _transactionRepository.CreateAsync(tx); diff --git a/src/Billing/Services/Implementations/ChargeSucceededHandler.cs b/src/Billing/Services/Implementations/ChargeSucceededHandler.cs index bd8ea7def2..20c4dcfa98 100644 --- a/src/Billing/Services/Implementations/ChargeSucceededHandler.cs +++ b/src/Billing/Services/Implementations/ChargeSucceededHandler.cs @@ -46,7 +46,7 @@ public class ChargeSucceededHandler : IChargeSucceededHandler return; } - var transaction = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId); + var transaction = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId); if (!transaction.PaymentMethodType.HasValue) { _logger.LogWarning("Charge success from unsupported source/method. {ChargeId}", charge.Id); diff --git a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs index bc3fa1bd56..89e40f0e43 100644 --- a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs @@ -2,8 +2,8 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using OneOf; using Stripe; using Event = Stripe.Event; @@ -59,10 +59,10 @@ public class SetupIntentSucceededHandler( return; } - await stripeAdapter.PaymentMethodAttachAsync(paymentMethod.Id, + await stripeAdapter.AttachPaymentMethodAsync(paymentMethod.Id, new PaymentMethodAttachOptions { Customer = customerId }); - await stripeAdapter.CustomerUpdateAsync(customerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { diff --git a/src/Billing/Services/Implementations/StripeEventUtilityService.cs b/src/Billing/Services/Implementations/StripeEventUtilityService.cs index 49e562de56..53512427c0 100644 --- a/src/Billing/Services/Implementations/StripeEventUtilityService.cs +++ b/src/Billing/Services/Implementations/StripeEventUtilityService.cs @@ -3,12 +3,12 @@ using Bit.Billing.Constants; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Models; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; using Braintree; using Stripe; using Customer = Stripe.Customer; @@ -112,7 +112,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService } public bool IsSponsoredSubscription(Subscription subscription) => - StaticStore.SponsoredPlans + SponsoredPlans.All .Any(p => subscription.Items .Any(i => i.Plan.Id == p.StripePlanId)); @@ -124,7 +124,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService /// /// /// /// - public Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId) + public async Task FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId) { var transaction = new Transaction { @@ -209,6 +209,24 @@ public class StripeEventUtilityService : IStripeEventUtilityService transaction.PaymentMethodType = PaymentMethodType.BankAccount; transaction.Details = $"ACH => {achCreditTransfer.BankName}, {achCreditTransfer.AccountNumber}"; } + else if (charge.PaymentMethodDetails.CustomerBalance != null) + { + var bankTransferType = await GetFundingBankTransferTypeAsync(charge); + + if (!string.IsNullOrEmpty(bankTransferType)) + { + transaction.PaymentMethodType = PaymentMethodType.BankAccount; + transaction.Details = bankTransferType switch + { + "eu_bank_transfer" => "EU Bank Transfer", + "gb_bank_transfer" => "GB Bank Transfer", + "jp_bank_transfer" => "JP Bank Transfer", + "mx_bank_transfer" => "MX Bank Transfer", + "us_bank_transfer" => "US Bank Transfer", + _ => "Bank Transfer" + }; + } + } break; } @@ -289,20 +307,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService } var btInvoiceAmount = Math.Round(invoice.AmountDue / 100M, 2); - var existingTransactions = organizationId.HasValue - ? await _transactionRepository.GetManyByOrganizationIdAsync(organizationId.Value) - : userId.HasValue - ? await _transactionRepository.GetManyByUserIdAsync(userId.Value) - : await _transactionRepository.GetManyByProviderIdAsync(providerId.Value); - - var duplicateTimeSpan = TimeSpan.FromHours(24); - var now = DateTime.UtcNow; - var duplicateTransaction = existingTransactions? - .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); - if (duplicateTransaction != null) + // Check if this invoice already has a Braintree transaction ID to prevent duplicate charges + if (invoice.Metadata?.ContainsKey("btTransactionId") ?? false) { - _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + - "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); + _logger.LogWarning("Invoice {InvoiceId} already has a Braintree transaction ({TransactionId}). " + + "Do not charge again to prevent duplicate.", + invoice.Id, + invoice.Metadata["btTransactionId"]); return false; } @@ -413,4 +424,55 @@ public class StripeEventUtilityService : IStripeEventUtilityService throw; } } + + /// + /// Retrieves the bank transfer type that funded a charge paid via customer balance. + /// + /// The charge to analyze. + /// + /// The bank transfer type (e.g., "us_bank_transfer", "eu_bank_transfer") if the charge was funded + /// by a bank transfer via customer balance, otherwise null. + /// + private async Task GetFundingBankTransferTypeAsync(Charge charge) + { + if (charge is not + { + CustomerId: not null, + PaymentIntentId: not null, + PaymentMethodDetails: { Type: "customer_balance" } + }) + { + return null; + } + + var cashBalanceTransactions = _stripeFacade.GetCustomerCashBalanceTransactions(charge.CustomerId); + + string bankTransferType = null; + var matchingPaymentIntentFound = false; + + await foreach (var cashBalanceTransaction in cashBalanceTransactions) + { + switch (cashBalanceTransaction) + { + case { Type: "funded", Funded: not null }: + { + bankTransferType = cashBalanceTransaction.Funded.BankTransfer.Type; + break; + } + case { Type: "applied_to_payment", AppliedToPayment: not null } + when cashBalanceTransaction.AppliedToPayment.PaymentIntentId == charge.PaymentIntentId: + { + matchingPaymentIntentFound = true; + break; + } + } + + if (matchingPaymentIntentFound && !string.IsNullOrEmpty(bankTransferType)) + { + return bankTransferType; + } + } + + return null; + } } diff --git a/src/Billing/Services/Implementations/StripeFacade.cs b/src/Billing/Services/Implementations/StripeFacade.cs index eef7ce009e..49cde981cd 100644 --- a/src/Billing/Services/Implementations/StripeFacade.cs +++ b/src/Billing/Services/Implementations/StripeFacade.cs @@ -11,6 +11,7 @@ public class StripeFacade : IStripeFacade { private readonly ChargeService _chargeService = new(); private readonly CustomerService _customerService = new(); + private readonly CustomerCashBalanceTransactionService _customerCashBalanceTransactionService = new(); private readonly EventService _eventService = new(); private readonly InvoiceService _invoiceService = new(); private readonly PaymentMethodService _paymentMethodService = new(); @@ -18,6 +19,7 @@ public class StripeFacade : IStripeFacade private readonly DiscountService _discountService = new(); private readonly SetupIntentService _setupIntentService = new(); private readonly TestClockService _testClockService = new(); + private readonly CouponService _couponService = new(); public async Task GetCharge( string chargeId, @@ -40,6 +42,13 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _customerService.GetAsync(customerId, customerGetOptions, requestOptions, cancellationToken); + public IAsyncEnumerable GetCustomerCashBalanceTransactions( + string customerId, + CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + => _customerCashBalanceTransactionService.ListAutoPagingAsync(customerId, customerCashBalanceTransactionListOptions, requestOptions, cancellationToken); + public async Task UpdateCustomer( string customerId, CustomerUpdateOptions customerUpdateOptions = null, @@ -98,6 +107,12 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _subscriptionService.ListAsync(options, requestOptions, cancellationToken); + public IAsyncEnumerable ListSubscriptionsAutoPagingAsync( + SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + _subscriptionService.ListAutoPagingAsync(options, requestOptions, cancellationToken); + public async Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -137,4 +152,11 @@ public class StripeFacade : IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default) => _testClockService.GetAsync(testClockId, testClockGetOptions, requestOptions, cancellationToken); + + public Task GetCoupon( + string couponId, + CouponGetOptions couponGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + _couponService.GetAsync(couponId, couponGetOptions, requestOptions, cancellationToken); } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 81aeb460c2..c10368d8c0 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,7 +1,5 @@ -using System.Globalization; -using Bit.Billing.Constants; +using Bit.Billing.Constants; using Bit.Billing.Jobs; -using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; @@ -111,8 +109,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler break; } - if (subscription.Status is StripeSubscriptionStatus.Unpaid && - subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + if (await IsPremiumSubscriptionAsync(subscription)) { await CancelSubscription(subscription.Id); await VoidOpenInvoices(subscription.Id); @@ -120,6 +117,20 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + break; + } + case StripeSubscriptionStatus.Incomplete when userId.HasValue: + { + // Handle Incomplete subscriptions for Premium users that have open invoices from failed payments + // This prevents duplicate subscriptions when users retry the subscription flow + if (await IsPremiumSubscriptionAsync(subscription) && + subscription.LatestInvoice is { Status: StripeInvoiceStatus.Open }) + { + await CancelSubscription(subscription.Id); + await VoidOpenInvoices(subscription.Id); + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + } + break; } case StripeSubscriptionStatus.Active when organizationId.HasValue: @@ -134,11 +145,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } case StripeSubscriptionStatus.Active when providerId.HasValue: { - var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); - if (!providerPortalTakeover) - { - break; - } var provider = await _providerRepository.GetByIdAsync(providerId.Value); if (provider != null) { @@ -197,6 +203,13 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } + private async Task IsPremiumSubscriptionAsync(Subscription subscription) + { + var premiumPlans = await _pricingClient.ListPremiumPlans(); + var premiumPriceIds = premiumPlans.SelectMany(p => new[] { p.Seat.StripePriceId, p.Storage.StripePriceId }).ToHashSet(); + return subscription.Items.Any(i => premiumPriceIds.Contains(i.Price.Id)); + } + /// /// Checks if the provider subscription status has changed from a non-active to an active status type /// If the previous status is already active(active,past-due,trialing),canceled,or null, then this will return false. @@ -321,13 +334,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler Event parsedEvent, Subscription currentSubscription) { - var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); - - if (!providerPortalTakeover) - { - return; - } - var provider = await _providerRepository.GetByIdAsync(providerId); if (provider == null) { @@ -343,22 +349,17 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler { var previousSubscription = parsedEvent.Data.PreviousAttributes.ToObject() as Subscription; - var updateIsSubscriptionGoingUnpaid = previousSubscription is - { - Status: + if (previousSubscription is + { + Status: StripeSubscriptionStatus.Trialing or StripeSubscriptionStatus.Active or StripeSubscriptionStatus.PastDue - } && currentSubscription is - { - Status: StripeSubscriptionStatus.Unpaid, - LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" - }; - - var updateIsManualSuspensionViaMetadata = CheckForManualSuspensionViaMetadata( - previousSubscription, currentSubscription); - - if (updateIsSubscriptionGoingUnpaid || updateIsManualSuspensionViaMetadata) + } && currentSubscription is + { + Status: StripeSubscriptionStatus.Unpaid, + LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" + }) { if (currentSubscription.TestClock != null) { @@ -369,14 +370,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler var subscriptionUpdateOptions = new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }; - if (updateIsManualSuspensionViaMetadata) - { - subscriptionUpdateOptions.Metadata = new Dictionary - { - ["suspended_provider_via_webhook_at"] = DateTime.UtcNow.ToString(CultureInfo.InvariantCulture) - }; - } - await _stripeFacade.UpdateSubscription(currentSubscription.Id, subscriptionUpdateOptions); } } @@ -399,37 +392,4 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } } - - private static bool CheckForManualSuspensionViaMetadata( - Subscription? previousSubscription, - Subscription currentSubscription) - { - /* - * When metadata on a subscription is updated, we'll receive an event that has: - * Previous Metadata: { newlyAddedKey: null } - * Current Metadata: { newlyAddedKey: newlyAddedValue } - * - * As such, our check for a manual suspension must ensure that the 'previous_attributes' does contain the - * 'metadata' property, but also that the "suspend_provider" key in that metadata is set to null. - * - * If we don't do this and instead do a null coalescing check on 'previous_attributes?.metadata?.TryGetValue', - * we'll end up marking an event where 'previous_attributes.metadata' = null (which could be any subscription update - * that does not update the metadata) the same as a manual suspension. - */ - const string key = "suspend_provider"; - - if (previousSubscription is not { Metadata: not null } || - !previousSubscription.Metadata.TryGetValue(key, out var previousValue)) - { - return false; - } - - if (previousValue == null) - { - return !string.IsNullOrEmpty( - currentSubscription.Metadata.TryGetValue(key, out var currentValue) ? currentValue : null); - } - - return false; - } } diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 1db469a4e2..004828dc48 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,4 +1,5 @@ -using Bit.Core; +using System.Globalization; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; @@ -8,7 +9,9 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Entities; -using Bit.Core.Models.Mail.UpdatedInvoiceIncoming; +using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Premium; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Repositories; @@ -16,6 +19,7 @@ using Bit.Core.Services; using Stripe; using Event = Stripe.Event; using Plan = Bit.Core.Models.StaticStore.Plan; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; namespace Bit.Billing.Services.Implementations; @@ -107,13 +111,22 @@ public class UpcomingInvoiceHandler( var milestone3 = featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3); - await AlignOrganizationSubscriptionConcernsAsync( + var subscriptionAligned = await AlignOrganizationSubscriptionConcernsAsync( organization, @event, subscription, plan, milestone3); + /* + * Subscription alignment sends out a different version of our Upcoming Invoice email, so we don't need to continue + * with processing. + */ + if (subscriptionAligned) + { + return; + } + // Don't send the upcoming invoice email unless the organization's on an annual plan. if (!plan.IsAnnual) { @@ -135,9 +148,7 @@ public class UpcomingInvoiceHandler( } } - await (milestone3 - ? SendUpdatedUpcomingInvoiceEmailsAsync([organization.BillingEmail]) - : SendUpcomingInvoiceEmailsAsync([organization.BillingEmail], invoice)); + await SendUpcomingInvoiceEmailsAsync([organization.BillingEmail], invoice); } private async Task AlignOrganizationTaxConcernsAsync( @@ -188,47 +199,64 @@ public class UpcomingInvoiceHandler( } } - private async Task AlignOrganizationSubscriptionConcernsAsync( + /// + /// Aligns the organization's subscription details with the specified plan and milestone requirements. + /// + /// The organization whose subscription is being updated. + /// The Stripe event associated with this operation. + /// The organization's subscription. + /// The organization's current plan. + /// A flag indicating whether the third milestone is enabled. + /// Whether the operation resulted in an updated subscription. + private async Task AlignOrganizationSubscriptionConcernsAsync( Organization organization, Event @event, Subscription subscription, Plan plan, bool milestone3) { - if (milestone3 && plan.Type == PlanType.FamiliesAnnually2019) + // currently these are the only plans that need aligned and both require the same flag and share most of the logic + if (!milestone3 || plan.Type is not (PlanType.FamiliesAnnually2019 or PlanType.FamiliesAnnually2025)) { - var passwordManagerItem = - subscription.Items.FirstOrDefault(item => item.Price.Id == plan.PasswordManager.StripePlanId); + return false; + } - if (passwordManagerItem == null) - { - logger.LogWarning("Could not find Organization's ({OrganizationId}) password manager item while processing '{EventType}' event ({EventID})", - organization.Id, @event.Type, @event.Id); - return; - } + var passwordManagerItem = + subscription.Items.FirstOrDefault(item => item.Price.Id == plan.PasswordManager.StripePlanId); - var families = await pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually); + if (passwordManagerItem == null) + { + logger.LogWarning("Could not find Organization's ({OrganizationId}) password manager item while processing '{EventType}' event ({EventID})", + organization.Id, @event.Type, @event.Id); + return false; + } - organization.PlanType = families.Type; - organization.Plan = families.Name; - organization.UsersGetPremium = families.UsersGetPremium; - organization.Seats = families.PasswordManager.BaseSeats; + var familiesPlan = await pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually); - var options = new SubscriptionUpdateOptions - { - Items = - [ - new SubscriptionItemOptions - { - Id = passwordManagerItem.Id, Price = families.PasswordManager.StripePlanId - } - ], - Discounts = - [ - new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone3SubscriptionDiscount } - ], - ProrationBehavior = ProrationBehavior.None - }; + organization.PlanType = familiesPlan.Type; + organization.Plan = familiesPlan.Name; + organization.UsersGetPremium = familiesPlan.UsersGetPremium; + organization.Seats = familiesPlan.PasswordManager.BaseSeats; + + var options = new SubscriptionUpdateOptions + { + Items = + [ + new SubscriptionItemOptions + { + Id = passwordManagerItem.Id, + Price = familiesPlan.PasswordManager.StripePlanId + } + ], + ProrationBehavior = ProrationBehavior.None + }; + + if (plan.Type == PlanType.FamiliesAnnually2019) + { + options.Discounts = + [ + new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone3SubscriptionDiscount } + ]; var premiumAccessAddOnItem = subscription.Items.FirstOrDefault(item => item.Price.Id == plan.PasswordManager.StripePremiumAccessPlanId); @@ -242,21 +270,35 @@ public class UpcomingInvoiceHandler( }); } - try + var seatAddOnItem = subscription.Items.FirstOrDefault(item => item.Price.Id == "personal-org-seat-annually"); + + if (seatAddOnItem != null) { - await organizationRepository.ReplaceAsync(organization); - await stripeFacade.UpdateSubscription(subscription.Id, options); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to align subscription concerns for Organization ({OrganizationID}) while processing '{EventType}' event ({EventID})", - organization.Id, - @event.Type, - @event.Id); + options.Items.Add(new SubscriptionItemOptions + { + Id = seatAddOnItem.Id, + Deleted = true + }); } } + + try + { + await organizationRepository.ReplaceAsync(organization); + await stripeFacade.UpdateSubscription(subscription.Id, options); + await SendFamiliesRenewalEmailAsync(organization, familiesPlan, plan); + return true; + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to align subscription concerns for Organization ({OrganizationID}) while processing '{EventType}' event ({EventID})", + organization.Id, + @event.Type, + @event.Id); + return false; + } } #endregion @@ -284,14 +326,21 @@ public class UpcomingInvoiceHandler( var milestone2Feature = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); if (milestone2Feature) { - await AlignPremiumUsersSubscriptionConcernsAsync(user, @event, subscription); + var subscriptionAligned = await AlignPremiumUsersSubscriptionConcernsAsync(user, @event, subscription); + + /* + * Subscription alignment sends out a different version of our Upcoming Invoice email, so we don't need to continue + * with processing. + */ + if (subscriptionAligned) + { + return; + } } if (user.Premium) { - await (milestone2Feature - ? SendUpdatedUpcomingInvoiceEmailsAsync(new List { user.Email }) - : SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice)); + await SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice); } } @@ -322,7 +371,7 @@ public class UpcomingInvoiceHandler( } } - private async Task AlignPremiumUsersSubscriptionConcernsAsync( + private async Task AlignPremiumUsersSubscriptionConcernsAsync( User user, Event @event, Subscription subscription) @@ -333,7 +382,7 @@ public class UpcomingInvoiceHandler( { logger.LogWarning("Could not find User's ({UserID}) premium subscription item while processing '{EventType}' event ({EventID})", user.Id, @event.Type, @event.Id); - return; + return false; } try @@ -352,6 +401,8 @@ public class UpcomingInvoiceHandler( ], ProrationBehavior = ProrationBehavior.None }); + await SendPremiumRenewalEmailAsync(user, plan); + return true; } catch (Exception exception) { @@ -360,6 +411,7 @@ public class UpcomingInvoiceHandler( "Failed to update user's ({UserID}) subscription price id while processing event with ID {EventID}", user.Id, @event.Id); + return false; } } @@ -494,15 +546,92 @@ public class UpcomingInvoiceHandler( } } - private async Task SendUpdatedUpcomingInvoiceEmailsAsync(IEnumerable emails) + private async Task SendFamiliesRenewalEmailAsync( + Organization organization, + Plan familiesPlan, + Plan planBeforeAlignment) { - var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); - var updatedUpcomingEmail = new UpdatedInvoiceUpcomingMail + await (planBeforeAlignment switch { - ToEmails = validEmails, - View = new UpdatedInvoiceUpcomingView() + { Type: PlanType.FamiliesAnnually2025 } => SendFamilies2020RenewalEmailAsync(organization, familiesPlan), + { Type: PlanType.FamiliesAnnually2019 } => SendFamilies2019RenewalEmailAsync(organization, familiesPlan), + _ => throw new InvalidOperationException("Unsupported families plan in SendFamiliesRenewalEmailAsync().") + }); + } + + private async Task SendFamilies2020RenewalEmailAsync(Organization organization, Plan familiesPlan) + { + var email = new Families2020RenewalMail + { + ToEmails = [organization.BillingEmail], + View = new Families2020RenewalMailView + { + MonthlyRenewalPrice = (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) + } }; - await mailer.SendEmail(updatedUpcomingEmail); + + await mailer.SendEmail(email); + } + + private async Task SendFamilies2019RenewalEmailAsync(Organization organization, Plan familiesPlan) + { + var coupon = await stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + if (coupon == null) + { + throw new InvalidOperationException($"Coupon for sending families 2019 email id:{CouponIDs.Milestone3SubscriptionDiscount} not found"); + } + + if (coupon.PercentOff == null) + { + throw new InvalidOperationException($"coupon.PercentOff for sending families 2019 email id:{CouponIDs.Milestone3SubscriptionDiscount} is null"); + } + + var discountedAnnualRenewalPrice = familiesPlan.PasswordManager.BasePrice * (100 - coupon.PercentOff.Value) / 100; + + var email = new Families2019RenewalMail + { + ToEmails = [organization.BillingEmail], + View = new Families2019RenewalMailView + { + BaseMonthlyRenewalPrice = (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")), + BaseAnnualRenewalPrice = familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")), + DiscountAmount = $"{coupon.PercentOff}%", + DiscountedAnnualRenewalPrice = discountedAnnualRenewalPrice.ToString("C", new CultureInfo("en-US")) + } + }; + + await mailer.SendEmail(email); + } + + private async Task SendPremiumRenewalEmailAsync( + User user, + PremiumPlan premiumPlan) + { + var coupon = await stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount); + if (coupon == null) + { + throw new InvalidOperationException($"Coupon for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} not found"); + } + + if (coupon.PercentOff == null) + { + throw new InvalidOperationException($"coupon.PercentOff for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} is null"); + } + + var discountedAnnualRenewalPrice = premiumPlan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; + + var email = new PremiumRenewalMail + { + ToEmails = [user.Email], + View = new PremiumRenewalMailView + { + BaseMonthlyRenewalPrice = (premiumPlan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")), + DiscountAmount = $"{coupon.PercentOff}%", + DiscountedMonthlyRenewalPrice = (discountedAnnualRenewalPrice / 12).ToString("C", new CultureInfo("en-US")) + } + }; + + await mailer.SendEmail(email); } #endregion diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index cdb9700ad5..1343dc0895 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -10,7 +10,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Context; using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories.Noop; -using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -129,12 +128,8 @@ public class Startup public void Configure( IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + IWebHostEnvironment env) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Billing/appsettings.Development.json b/src/Billing/appsettings.Development.json index 7c4889c22f..fe8e47b2f6 100644 --- a/src/Billing/appsettings.Development.json +++ b/src/Billing/appsettings.Development.json @@ -35,6 +35,7 @@ "billingSettings": { "onyx": { "personaId": 68 - } - } + } + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } diff --git a/src/Billing/appsettings.json b/src/Billing/appsettings.json index a2d6acd0a1..aa14f1d377 100644 --- a/src/Billing/appsettings.json +++ b/src/Billing/appsettings.json @@ -30,9 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 73aa162f22..338b150de6 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -134,6 +134,11 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable /// public bool UseAutomaticUserConfirmation { get; set; } + /// + /// If set to true, the organization has phishing protection enabled. + /// + public bool UsePhishingBlocker { get; set; } + public void SetNewId() { if (Id == default(Guid)) @@ -334,5 +339,6 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable UseOrganizationDomains = license.UseOrganizationDomains; UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies; UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation; + UsePhishingBlocker = license.UsePhishingBlocker; } } diff --git a/src/Core/AdminConsole/Entities/OrganizationIntegration.cs b/src/Core/AdminConsole/Entities/OrganizationIntegration.cs index 86de25ce9a..f1c96c8b98 100644 --- a/src/Core/AdminConsole/Entities/OrganizationIntegration.cs +++ b/src/Core/AdminConsole/Entities/OrganizationIntegration.cs @@ -2,8 +2,6 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -#nullable enable - namespace Bit.Core.AdminConsole.Entities; public class OrganizationIntegration : ITableObject diff --git a/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs b/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs index 52934cf7f3..a9ce676062 100644 --- a/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs +++ b/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs @@ -2,8 +2,6 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -#nullable enable - namespace Bit.Core.AdminConsole.Entities; public class OrganizationIntegrationConfiguration : ITableObject diff --git a/src/Core/AdminConsole/Enums/EventType.cs b/src/Core/AdminConsole/Enums/EventType.cs index 8073938fc5..916f408fe6 100644 --- a/src/Core/AdminConsole/Enums/EventType.cs +++ b/src/Core/AdminConsole/Enums/EventType.cs @@ -60,6 +60,7 @@ public enum EventType : int OrganizationUser_RejectedAuthRequest = 1514, OrganizationUser_Deleted = 1515, // Both user and organization user data were deleted OrganizationUser_Left = 1516, // User voluntarily left the organization + OrganizationUser_AutomaticallyConfirmed = 1517, Organization_Updated = 1600, Organization_PurgedVault = 1601, @@ -80,6 +81,8 @@ public enum EventType : int Organization_CollectionManagement_LimitItemDeletionDisabled = 1615, Organization_CollectionManagement_AllowAdminAccessToAllCollectionItemsEnabled = 1616, Organization_CollectionManagement_AllowAdminAccessToAllCollectionItemsDisabled = 1617, + Organization_ItemOrganization_Accepted = 1618, + Organization_ItemOrganization_Declined = 1619, Policy_Updated = 1700, diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 09fa4ec955..bd6daf7cdf 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -21,6 +21,7 @@ public enum PolicyType : byte UriMatchDefaults = 16, AutotypeDefaultSetting = 17, AutomaticUserConfirmation = 18, + BlockClaimedDomainAccountCreation = 19, } public static class PolicyTypeExtensions @@ -52,6 +53,7 @@ public static class PolicyTypeExtensions PolicyType.UriMatchDefaults => "URI match defaults", PolicyType.AutotypeDefaultSetting => "Autotype default setting", PolicyType.AutomaticUserConfirmation => "Automatically confirm invited users", + PolicyType.BlockClaimedDomainAccountCreation => "Block account creation for claimed domains", }; } } diff --git a/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs b/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs new file mode 100644 index 0000000000..5dce52d907 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs @@ -0,0 +1,559 @@ +using Azure.Messaging.ServiceBus; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Models.Teams; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.Services.NoopImplementations; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; +using TableStorageRepos = Bit.Core.Repositories.TableStorage; + +namespace Microsoft.Extensions.DependencyInjection; + +public static class EventIntegrationsServiceCollectionExtensions +{ + /// + /// Adds all event integrations commands, queries, and required cache infrastructure. + /// This method is idempotent and can be called multiple times safely. + /// + public static IServiceCollection AddEventIntegrationsCommandsQueries( + this IServiceCollection services, + GlobalSettings globalSettings) + { + // Ensure cache is registered first - commands depend on this keyed cache. + // This is idempotent for the same named cache, so it's safe to call. + services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + + // Add Validator + services.TryAddSingleton(); + + // Add all commands/queries + services.AddOrganizationIntegrationCommandsQueries(); + services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + return services; + } + + /// + /// Registers event write services based on available configuration. + /// + /// The service collection to add services to. + /// The global settings containing event logging configuration. + /// The service collection for chaining. + /// + /// + /// This method registers the appropriate IEventWriteService implementation based on the available + /// configuration, checking in the following priority order: + /// + /// + /// 1. Azure Service Bus - If all Azure Service Bus settings are present, registers + /// EventIntegrationEventWriteService with AzureServiceBusService as the publisher + /// + /// + /// 2. RabbitMQ - If all RabbitMQ settings are present, registers EventIntegrationEventWriteService with + /// RabbitMqService as the publisher + /// + /// + /// 3. Azure Queue Storage - If Events.ConnectionString is present, registers AzureQueueEventWriteService + /// + /// + /// 4. Repository (Self-Hosted) - If SelfHosted is true, registers RepositoryEventWriteService + /// + /// + /// 5. Noop - If none of the above are configured, registers NoopEventWriteService (no-op implementation) + /// + /// + public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) + { + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.Events.QueueName)) + { + services.TryAddSingleton(); + return services; + } + + if (globalSettings.SelfHosted) + { + services.TryAddSingleton(); + return services; + } + + services.TryAddSingleton(); + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing Azure Service Bus configuration. + /// The service collection for chaining. + /// + /// + /// If Azure Service Bus is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When Azure Service Bus is enabled, this method registers: + /// - IAzureServiceBusService and IEventIntegrationPublisher implementations + /// - Table Storage event repository + /// - Azure Table Storage event handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsAzureServiceBusEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing RabbitMQ configuration. + /// The service collection for chaining. + /// + /// + /// If RabbitMQ is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When RabbitMQ is enabled, this method registers: + /// - IRabbitMqService and IEventIntegrationPublisher implementations + /// - Event repository handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsRabbitMqEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers Slack integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Slack configuration. + /// The service collection for chaining. + /// + /// If all required Slack settings are configured (ClientId, ClientSecret, Scopes), registers the full SlackService, + /// including an HttpClient for Slack API calls. Otherwise, registers a NoopSlackService that performs no operations. + /// + public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) + { + services.AddHttpClient(SlackService.HttpClientName); + services.TryAddSingleton(); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers Microsoft Teams integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Teams configuration. + /// The service collection for chaining. + /// + /// If all required Teams settings are configured (ClientId, ClientSecret, Scopes), registers: + /// - TeamsService and its interfaces (IBot, ITeamsService) + /// - IBotFrameworkHttpAdapter with Teams credentials + /// - HttpClient for Teams API calls + /// Otherwise, registers a NoopTeamsService that performs no operations. + /// + public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) + { + services.AddHttpClient(TeamsService.HttpClientName); + services.TryAddSingleton(); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(_ => + new BotFrameworkHttpAdapter( + new TeamsBotCredentialProvider( + clientId: globalSettings.Teams.ClientId, + clientSecret: globalSettings.Teams.ClientSecret + ) + ) + ); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers event integration services including handlers, listeners, and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing integration configuration. + /// The service collection for chaining. + /// + /// + /// This method orchestrates the registration of all event integration components based on the enabled + /// message broker (Azure Service Bus or RabbitMQ). It is an internal method called by the public + /// entry points AddAzureServiceBusListeners and AddRabbitMqListeners. + /// + /// + /// NOTE: If both Azure Service Bus and RabbitMQ are configured, Azure Service Bus takes precedence. This means that + /// Azure Service Bus listeners will be registered (and RabbitMQ listeners will NOT) even if this event is called + /// from AddRabbitMqListeners when Azure Service Bus settings are configured. + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before invoking this method. + /// This method depends on distributed cache infrastructure being available for the keyed extended + /// cache registration. + /// + /// + /// Registered Services: + /// - Keyed ExtendedCache for event integrations + /// - Integration filter service + /// - Integration handlers for Slack, Webhook, Hec, Datadog, and Teams + /// - Hosted services for event and integration listeners (based on enabled message broker) + /// + /// + internal static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, + GlobalSettings globalSettings) + { + // Add common services + // NOTE: AddDistributedCache must be called by the caller before this method + services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + + // Add services in support of handlers + services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); + services.TryAddSingleton(TimeProvider.System); + services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); + services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); + + // Add integration handlers + services.TryAddSingleton, SlackIntegrationHandler>(); + services.TryAddSingleton, WebhookIntegrationHandler>(); + services.TryAddSingleton, DatadogIntegrationHandler>(); + services.TryAddSingleton, TeamsIntegrationHandler>(); + + var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); + var slackConfiguration = new SlackListenerConfiguration(globalSettings); + var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); + var hecConfiguration = new HecListenerConfiguration(globalSettings); + var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); + var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); + + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: repositoryConfiguration, + handler: provider.GetRequiredService(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = repositoryConfiguration.EventPrefetchCount, + MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddAzureServiceBusIntegration(slackConfiguration); + services.AddAzureServiceBusIntegration(webhookConfiguration); + services.AddAzureServiceBusIntegration(hecConfiguration); + services.AddAzureServiceBusIntegration(datadogConfiguration); + services.AddAzureServiceBusIntegration(teamsConfiguration); + + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredService(), + configuration: repositoryConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddRabbitMqIntegration(slackConfiguration); + services.AddRabbitMqIntegration(webhookConfiguration); + services.AddRabbitMqIntegration(hecConfiguration); + services.AddRabbitMqIntegration(datadogConfiguration); + services.AddRabbitMqIntegration(teamsConfiguration); + } + + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. AzureServiceBusEventListenerService - Hosted service for listening to event messages from Azure Service Bus + /// for this integration type + /// 3. AzureServiceBusIntegrationListenerService - Hosted service for listening to integration messages from + /// Azure Service Bus for this integration type + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + /// Service Bus processor options (PrefetchCount and MaxConcurrentCalls) are configured from the listener + /// configuration to optimize message throughput and concurrency. + /// + /// + internal static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.EventPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusIntegrationListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredService>(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. RabbitMqEventListenerService - Hosted service for listening to event messages from RabbitMQ for + /// this integration type + /// 3. RabbitMqIntegrationListenerService - Hosted service for listening to integration messages from RabbitMQ for + /// this integration type + /// + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + internal static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqIntegrationListenerService( + handler: provider.GetRequiredService>(), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService(), + timeProvider: provider.GetRequiredService() + ) + ) + ); + + return services; + } + + internal static IServiceCollection AddOrganizationIntegrationCommandsQueries(this IServiceCollection services) + { + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + + return services; + } + + internal static IServiceCollection AddOrganizationIntegrationConfigurationCommandsQueries(this IServiceCollection services) + { + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + + return services; + } + + /// + /// Determines if RabbitMQ is enabled for event integrations based on configuration settings. + /// + /// The global settings containing RabbitMQ configuration. + /// True if all required RabbitMQ settings are present; otherwise, false. + /// + /// Requires all the following settings to be configured: + /// - EventLogging.RabbitMq.HostName + /// - EventLogging.RabbitMq.Username + /// - EventLogging.RabbitMq.Password + /// - EventLogging.RabbitMq.EventExchangeName + /// + internal static bool IsRabbitMqEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName); + } + + /// + /// Determines if Azure Service Bus is enabled for event integrations based on configuration settings. + /// + /// The global settings containing Azure Service Bus configuration. + /// True if all required Azure Service Bus settings are present; otherwise, false. + /// + /// Requires both of the following settings to be configured: + /// - EventLogging.AzureServiceBus.ConnectionString + /// - EventLogging.AzureServiceBus.EventTopicName + /// + internal static bool IsAzureServiceBusEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName); + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..cb3ce8b9ea --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,64 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for creating organization integration configurations with validation and cache invalidation support. +/// +public class CreateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : ICreateOrganizationIntegrationConfigurationCommand +{ + public async Task CreateAsync( + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, configuration)) + { + throw new BadRequestException( + $"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + var created = await configurationRepository.CreateAsync(configuration); + + // Invalidate the cached configuration details + // Even though this is a new record, the cache could hold a stale empty list for this + if (created.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: created.EventType.Value + )); + } + + return created; + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..78768fd0d4 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,54 @@ +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for deleting organization integration configurations with cache invalidation support. +/// +public class DeleteOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache) + : IDeleteOrganizationIntegrationConfigurationCommand +{ + public async Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + + await configurationRepository.DeleteAsync(configuration); + + if (configuration.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + } + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..a2078c3c98 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,29 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Query implementation for retrieving organization integration configurations. +/// +public class GetOrganizationIntegrationConfigurationsQuery( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository) + : IGetOrganizationIntegrationConfigurationsQuery +{ + public async Task> GetManyByIntegrationAsync( + Guid organizationId, + Guid integrationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + + var configurations = await configurationRepository.GetManyByIntegrationAsync(integrationId); + return configurations.ToList(); + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..140cc79d1a --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,22 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for creating organization integration configurations. +/// +public interface ICreateOrganizationIntegrationConfigurationCommand +{ + /// + /// Creates a new configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The configuration to create. + /// The created configuration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task CreateAsync(Guid organizationId, Guid integrationId, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..3970676d40 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,19 @@ +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for deleting organization integration configurations. +/// +public interface IDeleteOrganizationIntegrationConfigurationCommand +{ + /// + /// Deletes a configuration from an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to delete. + /// + /// Thrown when the integration or configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..2bf806c458 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,19 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Query interface for retrieving organization integration configurations. +/// +public interface IGetOrganizationIntegrationConfigurationsQuery +{ + /// + /// Retrieves all configurations for a specific organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// A list of configurations associated with the integration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + Task> GetManyByIntegrationAsync(Guid organizationId, Guid integrationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..3e60a0af07 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,25 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for updating organization integration configurations. +/// +public interface IUpdateOrganizationIntegrationConfigurationCommand +{ + /// + /// Updates an existing configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to update. + /// The updated configuration data. + /// The updated configuration. + /// + /// Thrown when the integration or the configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task UpdateAsync(Guid organizationId, Guid integrationId, Guid configurationId, OrganizationIntegrationConfiguration updatedConfiguration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..f619e2ddf2 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,82 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for updating organization integration configurations with validation and cache invalidation support. +/// +public class UpdateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : IUpdateOrganizationIntegrationConfigurationCommand +{ + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, updatedConfiguration)) + { + throw new BadRequestException($"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + updatedConfiguration.Id = configuration.Id; + updatedConfiguration.CreationDate = configuration.CreationDate; + await configurationRepository.ReplaceAsync(updatedConfiguration); + + // If either old or new EventType is null (wildcard), invalidate all cached results + // for the specific integration + if (configuration.EventType == null || updatedConfiguration.EventType == null) + { + // Wildcard involved - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + + return updatedConfiguration; + } + + // Both are specific event types - invalidate specific cache entries + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + + // If event type changed, also clear the new event type's cache + if (configuration.EventType != updatedConfiguration.EventType) + { + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: updatedConfiguration.EventType.Value + )); + } + + return updatedConfiguration; + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..376451977c --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs @@ -0,0 +1,38 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for creating organization integrations with cache invalidation support. +/// +public class CreateOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] + IFusionCache cache) + : ICreateOrganizationIntegrationCommand +{ + public async Task CreateAsync(OrganizationIntegration integration) + { + var existingIntegrations = await integrationRepository + .GetManyByOrganizationAsync(integration.OrganizationId); + if (existingIntegrations.Any(i => i.Type == integration.Type)) + { + throw new BadRequestException("An integration of this type already exists for this organization."); + } + + var created = await integrationRepository.CreateAsync(integration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: integration.OrganizationId, + integrationType: integration.Type + )); + + return created; + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..614693cd82 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs @@ -0,0 +1,33 @@ +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for deleting organization integrations with cache invalidation support. +/// +public class DeleteOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache) + : IDeleteOrganizationIntegrationCommand +{ + public async Task DeleteAsync(Guid organizationId, Guid integrationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration is null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + + await integrationRepository.DeleteAsync(integration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs new file mode 100644 index 0000000000..f7bbaadb4a --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs @@ -0,0 +1,18 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Repositories; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; + +/// +/// Query implementation for retrieving organization integrations. +/// +public class GetOrganizationIntegrationsQuery(IOrganizationIntegrationRepository integrationRepository) + : IGetOrganizationIntegrationsQuery +{ + public async Task> GetManyByOrganizationAsync(Guid organizationId) + { + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + return integrations.ToList(); + } +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..e7b79eab13 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs @@ -0,0 +1,18 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for creating an OrganizationIntegration. +/// +public interface ICreateOrganizationIntegrationCommand +{ + /// + /// Creates a new organization integration. + /// + /// The OrganizationIntegration to create. + /// The created OrganizationIntegration. + /// Thrown when an integration + /// of the same type already exists for the organization. + Task CreateAsync(OrganizationIntegration integration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..be22b4e482 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs @@ -0,0 +1,16 @@ +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for deleting organization integrations. +/// +public interface IDeleteOrganizationIntegrationCommand +{ + /// + /// Deletes an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration to delete. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + Task DeleteAsync(Guid organizationId, Guid integrationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs new file mode 100644 index 0000000000..8cdea7f301 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs @@ -0,0 +1,16 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Query interface for retrieving organization integrations. +/// +public interface IGetOrganizationIntegrationsQuery +{ + /// + /// Retrieves all organization integrations for a specific organization. + /// + /// The unique identifier of the organization. + /// A list of organization integrations associated with the organization. + Task> GetManyByOrganizationAsync(Guid organizationId); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..f40086600d --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs @@ -0,0 +1,20 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for updating organization integrations. +/// +public interface IUpdateOrganizationIntegrationCommand +{ + /// + /// Updates an existing organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration to update. + /// The updated organization integration data. + /// The updated organization integration. + /// Thrown when the integration does not exist, + /// does not belong to the specified organization, or the integration type does not match. + Task UpdateAsync(Guid organizationId, Guid integrationId, OrganizationIntegration updatedIntegration); +} diff --git a/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..12a8620926 --- /dev/null +++ b/src/Core/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs @@ -0,0 +1,45 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for updating organization integrations with cache invalidation support. +/// +public class UpdateOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] + IFusionCache cache) + : IUpdateOrganizationIntegrationCommand +{ + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + OrganizationIntegration updatedIntegration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration is null || + integration.OrganizationId != organizationId || + integration.Type != updatedIntegration.Type) + { + throw new NotFoundException(); + } + + updatedIntegration.Id = integration.Id; + updatedIntegration.OrganizationId = integration.OrganizationId; + updatedIntegration.CreationDate = integration.CreationDate; + await integrationRepository.ReplaceAsync(updatedIntegration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + + return updatedIntegration; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs index fe33c45156..c44e550d15 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs @@ -1,8 +1,8 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; -using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; @@ -36,13 +36,18 @@ public class IntegrationTemplateContext(EventMessage eventMessage) public string DateIso8601 => Date.ToString("o"); public string EventMessage => JsonSerializer.Serialize(Event); - public User? User { get; set; } + public OrganizationUserUserDetails? User { get; set; } public string? UserName => User?.Name; public string? UserEmail => User?.Email; + public OrganizationUserType? UserType => User?.Type; - public User? ActingUser { get; set; } + public OrganizationUserUserDetails? ActingUser { get; set; } public string? ActingUserName => ActingUser?.Name; public string? ActingUserEmail => ActingUser?.Email; + public OrganizationUserType? ActingUserType => ActingUser?.Type; + + public Group? Group { get; set; } + public string? GroupName => Group?.Name; public Organization? Organization { get; set; } public string? OrganizationName => Organization?.DisplayName(); diff --git a/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs index 820b65dbfd..0368678641 100644 --- a/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs @@ -53,4 +53,5 @@ public interface IProfileOrganizationDetails bool UseAdminSponsoredFamilies { get; set; } bool UseOrganizationDomains { get; set; } bool UseAutomaticUserConfirmation { get; set; } + bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs b/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs new file mode 100644 index 0000000000..0dc6d1c352 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.AdminConsole.Models.Data.OrganizationUsers; + +public record AcceptedOrganizationUserToConfirm +{ + public required Guid OrganizationUserId { get; init; } + public required Guid UserId { get; init; } + public required string Key { get; init; } +} diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs index 3c02a4f50b..7c8389c103 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs @@ -29,6 +29,7 @@ public class OrganizationAbility UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; + UsePhishingBlocker = organization.UsePhishingBlocker; } public Guid Id { get; set; } @@ -51,4 +52,5 @@ public class OrganizationAbility public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs index 8d30bfc250..00b9280337 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs @@ -65,4 +65,5 @@ public class OrganizationUserOrganizationDetails : IProfileOrganizationDetails public bool UseAdminSponsoredFamilies { get; set; } public bool? IsAdminInitiated { get; set; } public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs index 6d182e197f..00ba706a41 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs @@ -20,6 +20,12 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I public string Email { get; set; } public string AvatarColor { get; set; } public string TwoFactorProviders { get; set; } + /// + /// Indicates whether the user has a personal premium subscription. + /// Does not include premium access from organizations - + /// do not use this to check whether the user can access premium features. + /// Null when the organization user is in Invited status (UserId is null). + /// public bool? Premium { get; set; } public OrganizationUserStatusType Status { get; set; } public OrganizationUserType Type { get; set; } @@ -63,11 +69,6 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I return UserId; } - public bool GetPremium() - { - return Premium.GetValueOrDefault(false); - } - public Permissions GetPermissions() { return string.IsNullOrWhiteSpace(Permissions) ? null diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index 84ff164943..484320c271 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -154,6 +154,7 @@ public class SelfHostedOrganizationDetails : Organization Status = Status, UseRiskInsights = UseRiskInsights, UseAdminSponsoredFamilies = UseAdminSponsoredFamilies, + UsePhishingBlocker = UsePhishingBlocker, }; } } diff --git a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs index 0d48f5cfa9..dcec028dcc 100644 --- a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -56,4 +56,5 @@ public class ProviderUserOrganizationDetails : IProfileOrganizationDetails public string? SsoExternalId { get; set; } public string? Permissions { get; set; } public string? ResetPasswordKey { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs index a78dd95260..b9bad6a346 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -18,7 +19,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IGroupRepository _groupRepository; private readonly IEventService _eventService; private readonly IOrganizationService _organizationService; @@ -27,7 +28,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA public ImportOrganizationUsersAndGroupsCommand(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IGroupRepository groupRepository, IEventService eventService, IOrganizationService organizationService) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index 595e487580..e6cc3da2a2 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; @@ -25,8 +24,6 @@ public class VerifyOrganizationDomainCommand( IEventService eventService, IGlobalSettings globalSettings, ICurrentContext currentContext, - IFeatureService featureService, - ISavePolicyCommand savePolicyCommand, IVNextSavePolicyCommand vNextSavePolicyCommand, IMailService mailService, IOrganizationUserRepository organizationUserRepository, @@ -144,15 +141,8 @@ public class VerifyOrganizationDomainCommand( PerformedBy = actingUser }; - if (featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) - { - var savePolicyModel = new SavePolicyModel(policyUpdate, actingUser); - await vNextSavePolicyCommand.SaveAsync(savePolicyModel); - } - else - { - await savePolicyCommand.SaveAsync(policyUpdate); - } + var savePolicyModel = new SavePolicyModel(policyUpdate, actingUser); + await vNextSavePolicyCommand.SaveAsync(savePolicyModel); } private async Task SendVerifiedDomainUserEmailAsync(OrganizationDomain domain) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs index 63f177b3f3..c763cc0cc2 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; @@ -34,6 +35,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IFeatureService _featureService; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IAutomaticUserConfirmationPolicyEnforcementValidator _automaticUserConfirmationPolicyEnforcementValidator; public AcceptOrgUserCommand( IDataProtectionProvider dataProtectionProvider, @@ -46,7 +48,8 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IFeatureService featureService, - IPolicyRequirementQuery policyRequirementQuery) + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) { // TODO: remove data protector when old token validation removed _dataProtector = dataProtectionProvider.CreateProtector(OrgUserInviteTokenable.DataProtectorPurpose); @@ -60,6 +63,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _featureService = featureService; _policyRequirementQuery = policyRequirementQuery; + _automaticUserConfirmationPolicyEnforcementValidator = automaticUserConfirmationPolicyEnforcementValidator; } public async Task AcceptOrgUserByEmailTokenAsync(Guid organizationUserId, User user, string emailToken, @@ -186,13 +190,19 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand } } - // Enforce Single Organization Policy of organization user is trying to join var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + await ValidateAutomaticUserConfirmationPolicyAsync(orgUser, allOrgUsers, user); + } + + // Enforce Single Organization Policy of organization user is trying to join var invitedSingleOrgPolicies = await _policyService.GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg, OrganizationUserStatusType.Invited); - if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + if (allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId) + && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) { throw new BadRequestException("You may not join this organization until you leave or remove all other organizations."); } @@ -255,4 +265,20 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand } } } + + private async Task ValidateAutomaticUserConfirmationPolicyAsync(OrganizationUser orgUser, + ICollection allOrgUsers, User user) + { + var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, allOrgUsers, user))) + .Match( + error => error.Message, + _ => string.Empty + ); + + if (!string.IsNullOrEmpty(error)) + { + throw new BadRequestException(error); + } + } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs new file mode 100644 index 0000000000..67b5f0da80 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs @@ -0,0 +1,186 @@ +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using OneOf.Types; +using CommandResult = Bit.Core.AdminConsole.Utilities.v2.Results.CommandResult; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public class AutomaticallyConfirmOrganizationUserCommand(IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IAutomaticallyConfirmOrganizationUsersValidator validator, + IEventService eventService, + IMailService mailService, + IUserRepository userRepository, + IPushRegistrationService pushRegistrationService, + IDeviceRepository deviceRepository, + IPushNotificationService pushNotificationService, + IPolicyRequirementQuery policyRequirementQuery, + ICollectionRepository collectionRepository, + TimeProvider timeProvider, + ILogger logger) : IAutomaticallyConfirmOrganizationUserCommand +{ + public async Task AutomaticallyConfirmOrganizationUserAsync(AutomaticallyConfirmOrganizationUserRequest request) + { + var validatorRequest = await RetrieveDataAsync(request); + + var validatedData = await validator.ValidateAsync(validatorRequest); + + return await validatedData.Match>( + error => Task.FromResult(new CommandResult(error)), + async _ => + { + var userToConfirm = new AcceptedOrganizationUserToConfirm + { + OrganizationUserId = validatedData.Request.OrganizationUser!.Id, + UserId = validatedData.Request.OrganizationUser.UserId!.Value, + Key = validatedData.Request.Key + }; + + // This operation is idempotent. If false, the user is already confirmed and no additional side effects are required. + if (!await organizationUserRepository.ConfirmOrganizationUserAsync(userToConfirm)) + { + return new None(); + } + + await CreateDefaultCollectionsAsync(validatedData.Request); + + await Task.WhenAll( + LogOrganizationUserConfirmedEventAsync(validatedData.Request), + SendConfirmedOrganizationUserEmailAsync(validatedData.Request), + SyncOrganizationKeysAsync(validatedData.Request) + ); + + return new None(); + } + ); + } + + private async Task SyncOrganizationKeysAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + await DeleteDeviceRegistrationAsync(request); + await PushSyncOrganizationKeysAsync(request); + } + + private async Task CreateDefaultCollectionsAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + if (!await ShouldCreateDefaultCollectionAsync(request)) + { + return; + } + + await collectionRepository.CreateAsync( + new Collection + { + OrganizationId = request.Organization!.Id, + Name = request.DefaultUserCollectionName, + Type = CollectionType.DefaultUserCollection + }, + groups: null, + [new CollectionAccessSelection + { + Id = request.OrganizationUser!.Id, + Manage = true + }]); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to create default collection for user."); + } + } + + /// + /// Determines whether a default collection should be created for an organization user during the confirmation process. + /// + /// + /// The validation request containing information about the user, organization, and collection settings. + /// + /// The result is a boolean value indicating whether a default collection should be created. + private async Task ShouldCreateDefaultCollectionAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) => + !string.IsNullOrWhiteSpace(request.DefaultUserCollectionName) + && (await policyRequirementQuery.GetAsync(request.OrganizationUser!.UserId!.Value)) + .RequiresDefaultCollectionOnConfirm(request.Organization!.Id); + + private async Task PushSyncOrganizationKeysAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + await pushNotificationService.PushSyncOrgKeysAsync(request.OrganizationUser!.UserId!.Value); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to push organization keys."); + } + } + + private async Task LogOrganizationUserConfirmedEventAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + await eventService.LogOrganizationUserEventAsync(request.OrganizationUser, + EventType.OrganizationUser_AutomaticallyConfirmed, + timeProvider.GetUtcNow().UtcDateTime); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to log OrganizationUser_AutomaticallyConfirmed event."); + } + } + + private async Task SendConfirmedOrganizationUserEmailAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + var user = await userRepository.GetByIdAsync(request.OrganizationUser!.UserId!.Value); + + await mailService.SendOrganizationConfirmedEmailAsync(request.Organization!.Name, + user!.Email, + request.OrganizationUser.AccessSecretsManager); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to send OrganizationUserConfirmed."); + } + } + + private async Task DeleteDeviceRegistrationAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + var devices = (await deviceRepository.GetManyByUserIdAsync(request.OrganizationUser!.UserId!.Value)) + .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) + .Select(d => d.Id.ToString()); + + await pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices, request.Organization!.Id.ToString()); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to delete device registration."); + } + } + + private async Task RetrieveDataAsync( + AutomaticallyConfirmOrganizationUserRequest request) + { + return new AutomaticallyConfirmOrganizationUserValidationRequest + { + OrganizationUserId = request.OrganizationUserId, + OrganizationId = request.OrganizationId, + Key = request.Key, + DefaultUserCollectionName = request.DefaultUserCollectionName, + PerformedBy = request.PerformedBy, + OrganizationUser = await organizationUserRepository.GetByIdAsync(request.OrganizationUserId), + Organization = await organizationRepository.GetByIdAsync(request.OrganizationId) + }; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs new file mode 100644 index 0000000000..fcc8dacf66 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs @@ -0,0 +1,29 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +/// +/// Automatically Confirm User Command Request +/// +public record AutomaticallyConfirmOrganizationUserRequest +{ + public required Guid OrganizationUserId { get; init; } + public required Guid OrganizationId { get; init; } + public required string Key { get; init; } + public required string DefaultUserCollectionName { get; init; } + public required IActingUser PerformedBy { get; init; } +} + +/// +/// Automatically Confirm User Validation Request +/// +/// +/// This is used to hold retrieved data and pass it to the validator +/// +public record AutomaticallyConfirmOrganizationUserValidationRequest : AutomaticallyConfirmOrganizationUserRequest +{ + public OrganizationUser? OrganizationUser { get; set; } + public Organization? Organization { get; set; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs new file mode 100644 index 0000000000..3375120516 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs @@ -0,0 +1,125 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public class AutomaticallyConfirmOrganizationUsersValidator( + IOrganizationUserRepository organizationUserRepository, + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator, + IUserService userService, + IPolicyRepository policyRepository) : IAutomaticallyConfirmOrganizationUsersValidator +{ + public async Task> ValidateAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request) + { + // User must exist + if (request is { OrganizationUser: null } || request.OrganizationUser is { UserId: null }) + { + return Invalid(request, new UserNotFoundError()); + } + + // Organization must exist + if (request is { Organization: null }) + { + return Invalid(request, new OrganizationNotFound()); + } + + // User must belong to the organization + if (request.OrganizationUser.OrganizationId != request.Organization.Id) + { + return Invalid(request, new OrganizationUserIdIsInvalid()); + } + + // User must be accepted + if (request is { OrganizationUser.Status: not OrganizationUserStatusType.Accepted }) + { + return Invalid(request, new UserIsNotAccepted()); + } + + // User must be of type User + if (request is { OrganizationUser.Type: not OrganizationUserType.User }) + { + return Invalid(request, new UserIsNotUserType()); + } + + if (!await OrganizationHasAutomaticallyConfirmUsersPolicyEnabledAsync(request)) + { + return Invalid(request, new AutomaticallyConfirmUsersPolicyIsNotEnabled()); + } + + if (!await OrganizationUserConformsToTwoFactorRequiredPolicyAsync(request)) + { + return Invalid(request, new UserDoesNotHaveTwoFactorEnabled()); + } + + if (await OrganizationUserConformsToAutomaticUserConfirmationPolicyAsync(request) is { } error) + { + return Invalid(request, error); + } + + return Valid(request); + } + + private async Task OrganizationHasAutomaticallyConfirmUsersPolicyEnabledAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) => + await policyRepository.GetByOrganizationIdTypeAsync(request.OrganizationId, PolicyType.AutomaticUserConfirmation) is { Enabled: true } + && request.Organization is { UseAutomaticUserConfirmation: true }; + + private async Task OrganizationUserConformsToTwoFactorRequiredPolicyAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + if ((await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync([request.OrganizationUser!.UserId!.Value])) + .Any(x => x.userId == request.OrganizationUser.UserId && x.twoFactorIsEnabled)) + { + return true; + } + + return !(await policyRequirementQuery.GetAsync(request.OrganizationUser.UserId!.Value)) + .IsTwoFactorRequiredForOrganization(request.Organization!.Id); + } + + /// + /// Validates whether the specified organization user complies with the automatic user confirmation policy. + /// This includes checks across all organizations the user is associated with to ensure they meet the compliance criteria. + /// + /// We are not checking single organization policy compliance here because automatically confirm users policy enforces + /// a stricter version and applies to all users. If you are compliant with Auto Confirm, you'll be in compliance with + /// Single Org. + /// + /// + /// The request model encapsulates the current organization, the user being validated, and all organization users associated + /// with that user. + /// + /// + /// An if the user fails to meet the automatic user confirmation policy, or null if the validation succeeds. + /// + private async Task OrganizationUserConformsToAutomaticUserConfirmationPolicyAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request) + { + var allOrganizationUsersForUser = await organizationUserRepository + .GetManyByUserAsync(request.OrganizationUser!.UserId!.Value); + + var user = await userService.GetUserByIdAsync(request.OrganizationUser!.UserId!.Value); + + return (await automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest( + request.OrganizationId, + allOrganizationUsersForUser, + user))) + .Match( + error => error, + _ => null + ); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs new file mode 100644 index 0000000000..e65db00f73 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs @@ -0,0 +1,16 @@ +using Bit.Core.AdminConsole.Utilities.v2; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public record OrganizationNotFound() : NotFoundError("Invalid organization"); +public record FailedToWriteToEventLog() : InternalError("Failed to write to event log"); +public record UserIsNotUserType() : BadRequestError("Only organization users with the User role can be automatically confirmed"); +public record UserIsNotAccepted() : BadRequestError("Cannot confirm user that has not accepted the invitation."); +public record OrganizationUserIdIsInvalid() : BadRequestError("Invalid organization user id."); +public record UserDoesNotHaveTwoFactorEnabled() : BadRequestError("User does not have two-step login enabled."); +public record UserCannotBelongToAnotherOrganization() : BadRequestError("Cannot confirm this member to the organization until they leave or remove all other organizations"); +public record OtherOrganizationDoesNotAllowOtherMembership() : BadRequestError("Cannot confirm this member to the organization because they are in another organization which forbids it."); +public record AutomaticallyConfirmUsersPolicyIsNotEnabled() : BadRequestError("Cannot confirm this member because the Automatically Confirm Users policy is not enabled."); +public record ProviderUsersCannotJoin() : BadRequestError("An organization the user is a part of has enabled Automatic User Confirmation policy, and it does not support provider users joining."); +public record UserCannotJoinProvider() : BadRequestError("An organization the user is a part of has enabled Automatic User Confirmation policy, and it does not support the user joining a provider."); +public record CurrentOrganizationUserIsNotPresentInRequest() : BadRequestError("The current organization user does not exist in the request."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs new file mode 100644 index 0000000000..544b65b53f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs @@ -0,0 +1,9 @@ +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public interface IAutomaticallyConfirmOrganizationUsersValidator +{ + Task> ValidateAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md new file mode 100644 index 0000000000..063b2f6a5c --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md @@ -0,0 +1,22 @@ +# Automatic User Confirmation + +Owned by: admin-console + +Automatic confirmation requests are server driven events that are sent to the admin's client where via a background service the confirmation will occur. The basic model +for the workflow is as follows: + +- The Api server sends an invite email to a user. +- The user accepts the invite request, which is sent back to the Api server +- The Api server sends a push-notification with the OrganizationId and UserId to a client admin session. +- The Client performs the key exchange in the background and POSTs the ConfirmRequest back to the Api server +- The Api server runs the OrgUser_Confirm sproc to confirm the user in the DB + +This Feature has the following security measures in place in order to achieve our security goals: + +- The single organization exemption for admins/owners is removed for this policy. + - This is enforced by preventing enabling the policy and organization plan feature if there are non-compliant users +- Emergency access is removed for all organization users +- Automatic confirmation will only apply to the User role (You cannot auto confirm admins/owners to an organization) +- The organization has no members with the Provider user type. + - This will also prevent the policy and organization plan feature from being enabled + - This will prevent sending organization invites to provider users diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 2fbe6be5c6..b6b49e93e9 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -33,6 +34,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IFeatureService _featureService; private readonly ICollectionRepository _collectionRepository; + private readonly IAutomaticUserConfirmationPolicyEnforcementValidator _automaticUserConfirmationPolicyEnforcementValidator; public ConfirmOrganizationUserCommand( IOrganizationRepository organizationRepository, @@ -47,7 +49,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand IDeviceRepository deviceRepository, IPolicyRequirementQuery policyRequirementQuery, IFeatureService featureService, - ICollectionRepository collectionRepository) + ICollectionRepository collectionRepository, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -62,6 +65,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand _policyRequirementQuery = policyRequirementQuery; _featureService = featureService; _collectionRepository = collectionRepository; + _automaticUserConfirmationPolicyEnforcementValidator = automaticUserConfirmationPolicyEnforcementValidator; } public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, @@ -127,6 +131,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand var organization = await _organizationRepository.GetByIdAsync(organizationId); var allUsersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validSelectedUserIds); + var users = await _userRepository.GetManyAsync(validSelectedUserIds); var usersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(validSelectedUserIds); @@ -188,6 +193,25 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand await ValidateTwoFactorAuthenticationPolicyAsync(user, organizationId, userTwoFactorEnabled); var hasOtherOrgs = userOrgs.Any(ou => ou.OrganizationId != organizationId); + + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationId, + userOrgs, + user))) + .Match( + error => new BadRequestException(error.Message), + _ => null + ); + + if (error is not null) + { + throw error; + } + } + var singleOrgPolicies = await _policyService.GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg); var otherSingleOrgPolicies = singleOrgPolicies.Where(p => p.OrganizationId != organizationId); @@ -267,8 +291,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var organizationDataOwnershipPolicy = - await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value); + var organizationDataOwnershipPolicy = await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value); if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId)) { return; @@ -311,8 +334,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var policyEligibleOrganizationUserIds = - await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId); + var policyEligibleOrganizationUserIds = await _policyRequirementQuery + .GetManyByOrganizationIdAsync(organizationId); var eligibleOrganizationUserIds = confirmedOrganizationUsers .Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id)) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs index 87c24c3ab4..c5c423f2bb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs @@ -1,4 +1,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs index 315d45ea69..71eff3ae69 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Repositories; -using static Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount.ValidationResultHelpers; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs index 6c8f7ee00c..a76104cc88 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs @@ -1,15 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2; -/// -/// A strongly typed error containing a reason that an action failed. -/// This is used for business logic validation and other expected errors, not exceptions. -/// -public abstract record Error(string Message); -/// -/// An type that maps to a NotFoundResult at the api layer. -/// -/// -public abstract record NotFoundError(string Message) : Error(Message); +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public record UserNotFoundError() : NotFoundError("Invalid user."); public record UserNotClaimedError() : Error("Member is not claimed by the organization."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs index 983a3a4f21..408d3e8bcd 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public interface IDeleteClaimedOrganizationUserAccountCommand { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs index f1a2c71b1b..05e97e896a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public interface IDeleteClaimedOrganizationUserAccountValidator { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs new file mode 100644 index 0000000000..a1776416ae --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs @@ -0,0 +1,40 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; + +/// +/// Command to automatically confirm an organization user. +/// +/// +/// The auto-confirm feature enables eligible client apps to confirm OrganizationUsers +/// automatically via push notifications, eliminating the need for manual administrator +/// intervention. Client apps receive a push notification, perform the required key exchange, +/// and submit an auto-confirm request to the server. This command processes those +/// client-initiated requests and should only be used in that specific context. +/// +public interface IAutomaticallyConfirmOrganizationUserCommand +{ + /// + /// Automatically confirms the organization user based on the provided request data. + /// + /// The request containing necessary information to confirm the organization user. + /// + /// This action has side effects. The side effects are + ///
    + ///
  • Creating an event log entry.
  • + ///
  • Syncing organization keys with the user.
  • + ///
  • Deleting any registered user devices for the organization.
  • + ///
  • Sending an email to the confirmed user.
  • + ///
  • Creating the default collection if applicable.
  • + ///
+ /// + /// Each of these actions is performed independently of each other and not guaranteed to be performed in any order. + /// Errors will be reported back for the actions that failed in a consolidated error message. + ///
+ /// + /// The result of the command. If there was an error, the result will contain a typed error describing the problem + /// that occurred. + /// + Task AutomaticallyConfirmOrganizationUserAsync(AutomaticallyConfirmOrganizationUserRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs new file mode 100644 index 0000000000..c7c80bd937 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs @@ -0,0 +1,69 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.AdminConsole.Utilities.DebuggingInstruments; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +public class BulkResendOrganizationInvitesCommand : IBulkResendOrganizationInvitesCommand +{ + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; + private readonly ILogger _logger; + + public BulkResendOrganizationInvitesCommand( + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, + ILogger logger) + { + _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; + _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; + _logger = logger; + } + + public async Task>> BulkResendInvitesAsync( + Guid organizationId, + Guid? invitingUserId, + IEnumerable organizationUsersId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + _logger.LogUserInviteStateDiagnostics(orgUsers); + + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null) + { + throw new NotFoundException(); + } + + var validUsers = new List(); + var result = new List>(); + + foreach (var orgUser in orgUsers) + { + if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) + { + result.Add(Tuple.Create(orgUser, "User invalid.")); + } + else + { + validUsers.Add(orgUser); + } + } + + if (validUsers.Any()) + { + await _sendOrganizationInvitesCommand.SendInvitesAsync( + new SendInvitesRequest(validUsers, org)); + + result.AddRange(validUsers.Select(u => Tuple.Create(u, ""))); + } + + return result; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs new file mode 100644 index 0000000000..342a06fcf9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs @@ -0,0 +1,20 @@ +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +public interface IBulkResendOrganizationInvitesCommand +{ + /// + /// Resend invites to multiple organization users in bulk. + /// + /// The ID of the organization. + /// The ID of the user who is resending the invites. + /// The IDs of the organization users to resend invites to. + /// A tuple containing the OrganizationUser and an error message (empty string if successful) + Task>> BulkResendInvitesAsync( + Guid organizationId, + Guid? invitingUserId, + IEnumerable organizationUsersId); +} + + diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs index f8bd988cab..2648a2e429 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs @@ -2,10 +2,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; using Bit.Core.AdminConsole.Utilities.Errors; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; @@ -15,7 +15,7 @@ public class InviteOrganizationUsersValidator( IOrganizationRepository organizationRepository, IInviteUsersPasswordManagerValidator inviteUsersPasswordManagerValidator, IUpdateSecretsManagerSubscriptionCommand secretsManagerSubscriptionCommand, - IPaymentService paymentService) : IInviteUsersValidator + IStripePaymentService paymentService) : IInviteUsersValidator { public async Task> ValidateAsync( InviteOrganizationUsersValidationRequest request) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs index 67155fe91a..9ba2fd1596 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs @@ -9,8 +9,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; @@ -22,7 +22,7 @@ public class InviteUsersPasswordManagerValidator( IInviteUsersEnvironmentValidator inviteUsersEnvironmentValidator, IInviteUsersOrganizationValidator inviteUsersOrganizationValidator, IProviderRepository providerRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository ) : IInviteUsersPasswordManagerValidator { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index 651a9225b4..ec42c8b402 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -29,7 +30,8 @@ public class RestoreOrganizationUserCommand( IUserRepository userRepository, IOrganizationService organizationService, IFeatureService featureService, - IPolicyRequirementQuery policyRequirementQuery) : IRestoreOrganizationUserCommand + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) : IRestoreOrganizationUserCommand { public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId) { @@ -300,6 +302,25 @@ public class RestoreOrganizationUserCommand( { throw new BadRequestException(user.Email + " is not compliant with the two-step login policy"); } + + if (featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var validationResult = await automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, + allOrgUsers, + user!)); + + var badRequestException = validationResult.Match( + error => new BadRequestException(user.Email + + " is not compliant with the automatic user confirmation policy: " + + error.Message), + _ => null); + + if (badRequestException is not null) + { + throw badRequestException; + } + } } private async Task IsTwoFactorRequiredForOrganizationAsync(Guid userId, Guid organizationId) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs similarity index 95% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs rename to src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs index 01ad2f05d2..7b5541c3ce 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs @@ -1,7 +1,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; public interface IRevokeOrganizationUserCommand { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs similarity index 99% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs rename to src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs index f24e0ae265..7aa67f0813 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs @@ -7,7 +7,7 @@ using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; public class RevokeOrganizationUserCommand( IEventService eventService, diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs new file mode 100644 index 0000000000..a30894c7d5 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs @@ -0,0 +1,8 @@ +using Bit.Core.AdminConsole.Utilities.v2; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public record UserAlreadyRevoked() : BadRequestError("Already revoked."); +public record CannotRevokeYourself() : BadRequestError("You cannot revoke yourself."); +public record OnlyOwnersCanRevokeOwners() : BadRequestError("Only owners can revoke other owners."); +public record MustHaveConfirmedOwner() : BadRequestError("Organization must have at least one confirmed owner."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs new file mode 100644 index 0000000000..e6471ad891 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs @@ -0,0 +1,8 @@ +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public interface IRevokeOrganizationUserCommand +{ + Task> RevokeUsersAsync(RevokeOrganizationUsersRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs new file mode 100644 index 0000000000..1a5cfd2c46 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs @@ -0,0 +1,9 @@ +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public interface IRevokeOrganizationUserValidator +{ + Task>> ValidateAsync(RevokeOrganizationUsersValidationRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs new file mode 100644 index 0000000000..ca501277a7 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs @@ -0,0 +1,114 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public class RevokeOrganizationUserCommand( + IOrganizationUserRepository organizationUserRepository, + IEventService eventService, + IPushNotificationService pushNotificationService, + IRevokeOrganizationUserValidator validator, + TimeProvider timeProvider, + ILogger logger) + : IRevokeOrganizationUserCommand +{ + public async Task> RevokeUsersAsync(RevokeOrganizationUsersRequest request) + { + var validationRequest = await CreateValidationRequestsAsync(request); + + var results = await validator.ValidateAsync(validationRequest); + + var validUsers = results.Where(r => r.IsValid).Select(r => r.Request).ToList(); + + await RevokeValidUsersAsync(validUsers); + + await Task.WhenAll( + LogRevokedOrganizationUsersAsync(validUsers, request.PerformedBy), + SendPushNotificationsAsync(validUsers) + ); + + return results.Select(r => r.Match( + error => new BulkCommandResult(r.Request.Id, error), + _ => new BulkCommandResult(r.Request.Id, new None()) + )); + } + + private async Task CreateValidationRequestsAsync( + RevokeOrganizationUsersRequest request) + { + var organizationUserToRevoke = await organizationUserRepository + .GetManyAsync(request.OrganizationUserIdsToRevoke); + + return new RevokeOrganizationUsersValidationRequest( + request.OrganizationId, + request.OrganizationUserIdsToRevoke, + request.PerformedBy, + organizationUserToRevoke); + } + + private async Task RevokeValidUsersAsync(ICollection validUsers) + { + if (validUsers.Count == 0) + { + return; + } + + await organizationUserRepository.RevokeManyByIdAsync(validUsers.Select(u => u.Id)); + } + + private async Task LogRevokedOrganizationUsersAsync( + ICollection revokedUsers, + IActingUser actingUser) + { + if (revokedUsers.Count == 0) + { + return; + } + + var eventDate = timeProvider.GetUtcNow().UtcDateTime; + + if (actingUser is SystemUser { SystemUserType: not null }) + { + var revokeEventsWithSystem = revokedUsers + .Select(user => (user, EventType.OrganizationUser_Revoked, actingUser.SystemUserType!.Value, + (DateTime?)eventDate)) + .ToList(); + await eventService.LogOrganizationUserEventsAsync(revokeEventsWithSystem); + } + else + { + var revokeEvents = revokedUsers + .Select(user => (user, EventType.OrganizationUser_Revoked, (DateTime?)eventDate)) + .ToList(); + await eventService.LogOrganizationUserEventsAsync(revokeEvents); + } + } + + private async Task SendPushNotificationsAsync(ICollection revokedUsers) + { + var userIdsToNotify = revokedUsers + .Where(user => user.UserId.HasValue) + .Select(user => user.UserId!.Value) + .Distinct() + .ToList(); + + foreach (var userId in userIdsToNotify) + { + try + { + await pushNotificationService.PushSyncOrgKeysAsync(userId); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to send push notification for user {UserId}.", userId); + } + } + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs new file mode 100644 index 0000000000..56996ffb53 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public record RevokeOrganizationUsersRequest( + Guid OrganizationId, + ICollection OrganizationUserIdsToRevoke, + IActingUser PerformedBy +); + +public record RevokeOrganizationUsersValidationRequest( + Guid OrganizationId, + ICollection OrganizationUserIdsToRevoke, + IActingUser PerformedBy, + ICollection OrganizationUsersToRevoke +) : RevokeOrganizationUsersRequest(OrganizationId, OrganizationUserIdsToRevoke, PerformedBy); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs new file mode 100644 index 0000000000..d2f47ed713 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs @@ -0,0 +1,39 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public class RevokeOrganizationUsersValidator(IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery) + : IRevokeOrganizationUserValidator +{ + public async Task>> ValidateAsync( + RevokeOrganizationUsersValidationRequest request) + { + var hasRemainingOwner = await hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(request.OrganizationId, + request.OrganizationUsersToRevoke.Select(x => x.Id) // users excluded because they are going to be revoked + ); + + return request.OrganizationUsersToRevoke.Select(x => + { + return x switch + { + _ when request.PerformedBy is not SystemUser + && x.UserId is not null + && x.UserId == request.PerformedBy.UserId => + Invalid(x, new CannotRevokeYourself()), + { Status: OrganizationUserStatusType.Revoked } => + Invalid(x, new UserAlreadyRevoked()), + { Type: OrganizationUserType.Owner } when !hasRemainingOwner => + Invalid(x, new MustHaveConfirmedOwner()), + { Type: OrganizationUserType.Owner } when !request.PerformedBy.IsOrganizationOwnerOrProvider => + Invalid(x, new OnlyOwnersCanRevokeOwners()), + + _ => Valid(x) + }; + }).ToList(); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs index 8d8ab8cdfc..7f24c4acd7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs @@ -3,11 +3,14 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -33,7 +36,7 @@ public interface ICloudOrganizationSignUpCommand public class CloudOrganizationSignUpCommand( IOrganizationUserRepository organizationUserRepository, IOrganizationBillingService organizationBillingService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyService policyService, IOrganizationRepository organizationRepository, IOrganizationApiKeyRepository organizationApiKeyRepository, @@ -42,7 +45,9 @@ public class CloudOrganizationSignUpCommand( IPushNotificationService pushNotificationService, ICollectionRepository collectionRepository, IDeviceRepository deviceRepository, - IPricingClient pricingClient) : ICloudOrganizationSignUpCommand + IPricingClient pricingClient, + IPolicyRequirementQuery policyRequirementQuery, + IFeatureService featureService) : ICloudOrganizationSignUpCommand { public async Task SignUpOrganizationAsync(OrganizationSignup signup) { @@ -75,8 +80,7 @@ public class CloudOrganizationSignUpCommand( PlanType = plan!.Type, Seats = (short)(plan.PasswordManager.BaseSeats + signup.AdditionalSeats), MaxCollections = plan.PasswordManager.MaxCollections, - MaxStorageGb = !plan.PasswordManager.BaseStorageGb.HasValue ? - (short?)null : (short)(plan.PasswordManager.BaseStorageGb.Value + signup.AdditionalStorageGb), + MaxStorageGb = (short)(plan.PasswordManager.BaseStorageGb + signup.AdditionalStorageGb), UsePolicies = plan.HasPolicies, UseSso = plan.HasSso, UseGroups = plan.HasGroups, @@ -237,6 +241,17 @@ public class CloudOrganizationSignUpCommand( private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs index 6474914b48..da678ece71 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs @@ -2,6 +2,8 @@ #nullable disable using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Entities; @@ -28,6 +30,8 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand private readonly IGlobalSettings _globalSettings; private readonly IPolicyService _policyService; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IFeatureService _featureService; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public InitPendingOrganizationCommand( IOrganizationService organizationService, @@ -37,7 +41,9 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand IDataProtectionProvider dataProtectionProvider, IGlobalSettings globalSettings, IPolicyService policyService, - IOrganizationUserRepository organizationUserRepository + IOrganizationUserRepository organizationUserRepository, + IFeatureService featureService, + IPolicyRequirementQuery policyRequirementQuery ) { _organizationService = organizationService; @@ -48,6 +54,8 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand _globalSettings = globalSettings; _policyService = policyService; _organizationUserRepository = organizationUserRepository; + _featureService = featureService; + _policyRequirementQuery = policyRequirementQuery; } public async Task InitPendingOrganizationAsync(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken) @@ -113,6 +121,17 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await _policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs new file mode 100644 index 0000000000..85fbcd2740 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs @@ -0,0 +1,15 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; + +public interface IOrganizationUpdateCommand +{ + /// + /// Updates an organization's information in the Bitwarden database and Stripe (if required). + /// Also optionally updates an organization's public-private keypair if it was not created with one. + /// On self-host, only the public-private keys will be updated because all other properties are fixed by the license file. + /// + /// The update request containing the details to be updated. + Task UpdateAsync(OrganizationUpdateRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs index 6a81130402..f73c49c811 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,13 +13,13 @@ public class OrganizationDeleteCommand : IOrganizationDeleteCommand { private readonly IApplicationCacheService _applicationCacheService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; public OrganizationDeleteCommand( IApplicationCacheService applicationCacheService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ISsoConfigRepository ssoConfigRepository) { _applicationCacheService = applicationCacheService; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs index 27e70fbe2d..4a8f08a4f7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs @@ -73,7 +73,7 @@ public class ProviderClientOrganizationSignUpCommand : IProviderClientOrganizati PlanType = plan!.Type, Seats = signup.AdditionalSeats, MaxCollections = plan.PasswordManager.MaxCollections, - MaxStorageGb = 1, + MaxStorageGb = plan.PasswordManager.BaseStorageGb, UsePolicies = plan.HasPolicies, UseSso = plan.HasSso, UseOrganizationDomains = plan.HasOrganizationDomains, diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs index 446d7339ca..82260aa6a7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -39,7 +40,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; public ResellerClientOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -48,7 +49,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati IOrganizationUserRepository organizationUserRepository, IEventService eventService, ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, - IPaymentService paymentService) + IStripePaymentService paymentService) { _organizationRepository = organizationRepository; _organizationApiKeyRepository = organizationApiKeyRepository; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs index c52b7c10c9..9abce991c3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs @@ -2,6 +2,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; @@ -30,7 +32,9 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp private readonly ILicensingService _licensingService; private readonly IPolicyService _policyService; private readonly IGlobalSettings _globalSettings; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; + private readonly IFeatureService _featureService; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public SelfHostedOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -44,7 +48,9 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp ILicensingService licensingService, IPolicyService policyService, IGlobalSettings globalSettings, - IPaymentService paymentService) + IStripePaymentService paymentService, + IFeatureService featureService, + IPolicyRequirementQuery policyRequirementQuery) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -58,6 +64,8 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp _policyService = policyService; _globalSettings = globalSettings; _paymentService = paymentService; + _featureService = featureService; + _policyRequirementQuery = policyRequirementQuery; } public async Task<(Organization organization, OrganizationUser? organizationUser)> SignUpAsync( @@ -103,6 +111,17 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await _policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs new file mode 100644 index 0000000000..83318fd1e6 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs @@ -0,0 +1,77 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +public class OrganizationUpdateCommand( + IOrganizationService organizationService, + IOrganizationRepository organizationRepository, + IGlobalSettings globalSettings, + IOrganizationBillingService organizationBillingService +) : IOrganizationUpdateCommand +{ + public async Task UpdateAsync(OrganizationUpdateRequest request) + { + var organization = await organizationRepository.GetByIdAsync(request.OrganizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (globalSettings.SelfHosted) + { + return await UpdateSelfHostedAsync(organization, request); + } + + return await UpdateCloudAsync(organization, request); + } + + private async Task UpdateCloudAsync(Organization organization, OrganizationUpdateRequest request) + { + // Store original values for comparison + var originalName = organization.Name; + var originalBillingEmail = organization.BillingEmail; + + // Apply updates to organization + organization.UpdateDetails(request); + organization.BackfillPublicPrivateKeys(request); + await organizationService.ReplaceAndUpdateCacheAsync(organization, EventType.Organization_Updated); + + // Update billing information in Stripe if required + await UpdateBillingAsync(organization, originalName, originalBillingEmail); + + return organization; + } + + /// + /// Self-host cannot update the organization details because they are set by the license file. + /// However, this command does offer a soft migration pathway for organizations without public and private keys. + /// If we remove this migration code in the future, this command and endpoint can become cloud only. + /// + private async Task UpdateSelfHostedAsync(Organization organization, OrganizationUpdateRequest request) + { + organization.BackfillPublicPrivateKeys(request); + await organizationService.ReplaceAndUpdateCacheAsync(organization, EventType.Organization_Updated); + return organization; + } + + private async Task UpdateBillingAsync(Organization organization, string originalName, string? originalBillingEmail) + { + // Update Stripe if name or billing email changed + var shouldUpdateBilling = originalName != organization.Name || + originalBillingEmail != organization.BillingEmail; + + if (!shouldUpdateBilling) + { + return; + } + + await organizationBillingService.UpdateOrganizationNameAndEmail(organization); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateExtensions.cs new file mode 100644 index 0000000000..e90c39bc54 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateExtensions.cs @@ -0,0 +1,43 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +public static class OrganizationUpdateExtensions +{ + /// + /// Updates the organization name and/or billing email. + /// Any null property on the request object will be skipped. + /// + public static void UpdateDetails(this Organization organization, OrganizationUpdateRequest request) + { + // These values may or may not be sent by the client depending on the operation being performed. + // Skip any values not provided. + if (request.Name is not null) + { + organization.Name = request.Name; + } + + if (request.BillingEmail is not null) + { + organization.BillingEmail = request.BillingEmail.ToLowerInvariant().Trim(); + } + } + + /// + /// Updates the organization public and private keys if provided and not already set. + /// This is legacy code for old organizations that were not created with a public/private keypair. It is a soft + /// migration that will silently migrate organizations when they change their details. + /// + public static void BackfillPublicPrivateKeys(this Organization organization, OrganizationUpdateRequest request) + { + if (!string.IsNullOrWhiteSpace(request.PublicKey) && string.IsNullOrWhiteSpace(organization.PublicKey)) + { + organization.PublicKey = request.PublicKey; + } + + if (!string.IsNullOrWhiteSpace(request.EncryptedPrivateKey) && string.IsNullOrWhiteSpace(organization.PrivateKey)) + { + organization.PrivateKey = request.EncryptedPrivateKey; + } + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs new file mode 100644 index 0000000000..21d4948678 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs @@ -0,0 +1,33 @@ +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +/// +/// Request model for updating the name, billing email, and/or public-private keys for an organization (legacy migration code). +/// Any combination of these properties can be updated, so they are optional. If none are specified it will not update anything. +/// +public record OrganizationUpdateRequest +{ + /// + /// The ID of the organization to update. + /// + public required Guid OrganizationId { get; init; } + + /// + /// The new organization name to apply (optional, this is skipped if not provided). + /// + public string? Name { get; init; } + + /// + /// The new billing email address to apply (optional, this is skipped if not provided). + /// + public string? BillingEmail { get; init; } + + /// + /// The organization's public key to set (optional, only set if not already present on the organization). + /// + public string? PublicKey { get; init; } + + /// + /// The organization's encrypted private key to set (optional, only set if not already present on the organization). + /// + public string? EncryptedPrivateKey { get; init; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs index 450f425bdf..e4d5a94c4c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs @@ -1,12 +1,12 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Microsoft.Extensions.Logging; namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations; -public class UpdateOrganizationSubscriptionCommand(IPaymentService paymentService, +public class UpdateOrganizationSubscriptionCommand(IStripePaymentService paymentService, IOrganizationRepository repository, TimeProvider timeProvider, ILogger logger) : IUpdateOrganizationSubscriptionCommand diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs new file mode 100644 index 0000000000..962da4bef7 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs @@ -0,0 +1,44 @@ +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +/// +/// Request object for +/// +public record AutomaticUserConfirmationPolicyEnforcementRequest +{ + /// + /// Organization to be validated + /// + public Guid OrganizationId { get; } + + /// + /// All organization users that match the provided user. + /// + public ICollection AllOrganizationUsers { get; } + + /// + /// User associated with the organization user to be confirmed + /// + public User User { get; } + + /// + /// Request object for . + /// + /// + /// This record is used to encapsulate the data required for handling the automatic confirmation policy enforcement. + /// + /// The organization to be validated. + /// All organization users that match the provided user. + /// The user entity connecting all org users provided. + public AutomaticUserConfirmationPolicyEnforcementRequest( + Guid organizationId, + IEnumerable organizationUsers, + User user) + { + OrganizationId = organizationId; + AllOrganizationUsers = organizationUsers.ToArray(); + User = user; + } +} + diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs new file mode 100644 index 0000000000..633b84d2b9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs @@ -0,0 +1,49 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +public class AutomaticUserConfirmationPolicyEnforcementValidator( + IPolicyRequirementQuery policyRequirementQuery, + IProviderUserRepository providerUserRepository) + : IAutomaticUserConfirmationPolicyEnforcementValidator +{ + public async Task> IsCompliantAsync( + AutomaticUserConfirmationPolicyEnforcementRequest request) + { + var automaticUserConfirmationPolicyRequirement = await policyRequirementQuery + .GetAsync(request.User.Id); + + var currentOrganizationUser = request.AllOrganizationUsers + .FirstOrDefault(x => x.OrganizationId == request.OrganizationId + && x.UserId == request.User.Id); + + if (currentOrganizationUser is null) + { + return Invalid(request, new CurrentOrganizationUserIsNotPresentInRequest()); + } + + if (automaticUserConfirmationPolicyRequirement.IsEnabled(request.OrganizationId)) + { + if ((await providerUserRepository.GetManyByUserAsync(request.User.Id)).Count != 0) + { + return Invalid(request, new ProviderUsersCannotJoin()); + } + + if (request.AllOrganizationUsers.Count > 1) + { + return Invalid(request, new UserCannotBelongToAnotherOrganization()); + } + } + + if (automaticUserConfirmationPolicyRequirement.IsEnabledForOrganizationsOtherThan(currentOrganizationUser.OrganizationId)) + { + return Invalid(request, new OtherOrganizationDoesNotAllowOtherMembership()); + } + + return Valid(request); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs new file mode 100644 index 0000000000..7bc1664140 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs @@ -0,0 +1,28 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +/// +/// Used to enforce the Automatic User Confirmation policy. It uses the to retrieve +/// the . It is used to check to make sure the given user is +/// valid for the Automatic User Confirmation policy. It also validates that the given user is not a provider +/// or a member of another organization regardless of status or type. +/// +public interface IAutomaticUserConfirmationPolicyEnforcementValidator +{ + + /// + /// Checks if the given user is compliant with the Automatic User Confirmation policy. + /// + /// To be compliant, a user must + /// - not be a member of a provider + /// - not be a member of another organization + /// + /// + /// + /// This uses the validation result pattern to avoid throwing exceptions. + /// + /// A validation result with the error message if applicable. + Task> IsCompliantAsync(AutomaticUserConfirmationPolicyEnforcementRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs index e2bca930d1..57140317e3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs @@ -4,6 +4,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; @@ -16,19 +18,22 @@ public class SavePolicyCommand : ISavePolicyCommand private readonly IReadOnlyDictionary _policyValidators; private readonly TimeProvider _timeProvider; private readonly IPostSavePolicySideEffect _postSavePolicySideEffect; + private readonly IPushNotificationService _pushNotificationService; public SavePolicyCommand(IApplicationCacheService applicationCacheService, IEventService eventService, IPolicyRepository policyRepository, IEnumerable policyValidators, TimeProvider timeProvider, - IPostSavePolicySideEffect postSavePolicySideEffect) + IPostSavePolicySideEffect postSavePolicySideEffect, + IPushNotificationService pushNotificationService) { _applicationCacheService = applicationCacheService; _eventService = eventService; _policyRepository = policyRepository; _timeProvider = timeProvider; _postSavePolicySideEffect = postSavePolicySideEffect; + _pushNotificationService = pushNotificationService; var policyValidatorsDict = new Dictionary(); foreach (var policyValidator in policyValidators) @@ -75,6 +80,8 @@ public class SavePolicyCommand : ISavePolicyCommand await _policyRepository.UpsertAsync(policy); await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + await PushPolicyUpdateToClients(policy.OrganizationId, policy); + return policy; } @@ -152,4 +159,17 @@ public class SavePolicyCommand : ISavePolicyCommand var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdate.Type); return (savedPoliciesDict, currentPolicy); } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => this._pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs index 5d40cb211f..38e417d085 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs @@ -5,6 +5,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Int using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; @@ -15,7 +17,8 @@ public class VNextSavePolicyCommand( IPolicyRepository policyRepository, IEnumerable policyUpdateEventHandlers, TimeProvider timeProvider, - IPolicyEventHandlerFactory policyEventHandlerFactory) + IPolicyEventHandlerFactory policyEventHandlerFactory, + IPushNotificationService pushNotificationService) : IVNextSavePolicyCommand { @@ -74,7 +77,7 @@ public class VNextSavePolicyCommand( policy.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; await policyRepository.UpsertAsync(policy); - + await PushPolicyUpdateToClients(policyUpdateRequest.OrganizationId, policy); return policy; } @@ -192,4 +195,17 @@ public class VNextSavePolicyCommand( var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type); return savedPoliciesDict; } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs new file mode 100644 index 0000000000..3430f33a77 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs @@ -0,0 +1,48 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; + +/// +/// Represents the enforcement status of the Automatic User Confirmation policy. +/// +/// +/// The Automatic User Confirmation policy is enforced against all types of users regardless of status or type. +/// +/// Users cannot: +///
    +///
  • Be a member of another organization (similar to Single Organization Policy)
  • +///
  • Cannot be a provider
  • +///
+///
+/// Collection of policy details that apply to this user id +public class AutomaticUserConfirmationPolicyRequirement(IEnumerable policyDetails) : IPolicyRequirement +{ + public bool CannotBeGrantedEmergencyAccess() => policyDetails.Any(); + + public bool CannotJoinProvider() => policyDetails.Any(); + + public bool CannotCreateProvider() => policyDetails.Any(); + + public bool CannotCreateNewOrganization() => policyDetails.Any(); + + public bool IsEnabled(Guid organizationId) => policyDetails.Any(p => p.OrganizationId == organizationId); + + public bool IsEnabledForOrganizationsOtherThan(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId != organizationId); +} + +public class AutomaticUserConfirmationPolicyRequirementFactory : BasePolicyRequirementFactory +{ + public override PolicyType PolicyType => PolicyType.AutomaticUserConfirmation; + + protected override IEnumerable ExemptRoles => []; + + protected override IEnumerable ExemptStatuses => []; + + protected override bool ExemptProviders => false; + + public override AutomaticUserConfirmationPolicyRequirement Create(IEnumerable policyDetails) => + new(policyDetails); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs new file mode 100644 index 0000000000..d1e1efafd9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs @@ -0,0 +1,21 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; + +public class SingleOrganizationPolicyRequirement(IEnumerable policyDetails) : IPolicyRequirement +{ + public bool IsSingleOrgEnabledForThisOrganization(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId == organizationId); + + public bool IsSingleOrgEnabledForOrganizationsOtherThan(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId != organizationId); +} + +public class SingleOrganizationPolicyRequirementFactory : BasePolicyRequirementFactory +{ + public override PolicyType PolicyType => PolicyType.SingleOrg; + + public override SingleOrganizationPolicyRequirement Create(IEnumerable policyDetails) => + new(policyDetails); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 7c1987865a..f69935715d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; @@ -23,6 +24,8 @@ public static class PolicyServiceCollectionExtensions services.AddPolicyRequirements(); services.AddPolicySideEffects(); services.AddPolicyUpdateEvents(); + + services.AddScoped(); } [Obsolete("Use AddPolicyUpdateEvents instead.")] @@ -35,6 +38,8 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); + services.AddScoped(); } [Obsolete("Use AddPolicyUpdateEvents instead.")] @@ -53,6 +58,7 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); } @@ -65,5 +71,7 @@ public static class PolicyServiceCollectionExtensions services.AddScoped, RequireSsoPolicyRequirementFactory>(); services.AddScoped, RequireTwoFactorPolicyRequirementFactory>(); services.AddScoped, MasterPasswordPolicyRequirementFactory>(); + services.AddScoped, SingleOrganizationPolicyRequirementFactory>(); + services.AddScoped, AutomaticUserConfirmationPolicyRequirementFactory>(); } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs index c0d302df02..86c94147f4 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; @@ -17,26 +18,13 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; ///
  • All organization users are compliant with the Single organization policy
  • ///
  • No provider users exist
  • /// -/// -/// This class also performs side effects when the policy is being enabled or disabled. They are: -///
      -///
    • Sets the UseAutomaticUserConfirmation organization feature to match the policy update
    • -///
    ///
    public class AutomaticUserConfirmationPolicyEventHandler( IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IPolicyRepository policyRepository, - IOrganizationRepository organizationRepository, - TimeProvider timeProvider) - : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent, IEnforceDependentPoliciesEvent + IProviderUserRepository providerUserRepository) + : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent { public PolicyType Type => PolicyType.AutomaticUserConfirmation; - public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) => - await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); - - private const string _singleOrgPolicyNotEnabledErrorMessage = - "The Single organization policy must be enabled before enabling the Automatically confirm invited users policy."; private const string _usersNotCompliantWithSingleOrgErrorMessage = "All organization users must be compliant with the Single organization policy before enabling the Automatically confirm invited users policy. Please remove users who are members of multiple organizations."; @@ -61,27 +49,20 @@ public class AutomaticUserConfirmationPolicyEventHandler( public async Task ValidateAsync(SavePolicyModel savePolicyModel, Policy? currentPolicy) => await ValidateAsync(savePolicyModel.PolicyUpdate, currentPolicy); - public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) - { - var organization = await organizationRepository.GetByIdAsync(policyUpdate.OrganizationId); - - if (organization is not null) - { - organization.UseAutomaticUserConfirmation = policyUpdate.Enabled; - organization.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; - await organizationRepository.UpsertAsync(organization); - } - } + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => + Task.CompletedTask; private async Task ValidateEnablingPolicyAsync(Guid organizationId) { - var singleOrgValidationError = await ValidateSingleOrgPolicyComplianceAsync(organizationId); + var organizationUsers = await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + + var singleOrgValidationError = await ValidateUserComplianceWithSingleOrgAsync(organizationId, organizationUsers); if (!string.IsNullOrWhiteSpace(singleOrgValidationError)) { return singleOrgValidationError; } - var providerValidationError = await ValidateNoProviderUsersAsync(organizationId); + var providerValidationError = await ValidateNoProviderUsersAsync(organizationUsers); if (!string.IsNullOrWhiteSpace(providerValidationError)) { return providerValidationError; @@ -90,42 +71,24 @@ public class AutomaticUserConfirmationPolicyEventHandler( return string.Empty; } - private async Task ValidateSingleOrgPolicyComplianceAsync(Guid organizationId) + private async Task ValidateUserComplianceWithSingleOrgAsync(Guid organizationId, + ICollection organizationUsers) { - var singleOrgPolicy = await policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.SingleOrg); - if (singleOrgPolicy is not { Enabled: true }) - { - return _singleOrgPolicyNotEnabledErrorMessage; - } - - return await ValidateUserComplianceWithSingleOrgAsync(organizationId); - } - - private async Task ValidateUserComplianceWithSingleOrgAsync(Guid organizationId) - { - var organizationUsers = (await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId)) - .Where(ou => ou.Status != OrganizationUserStatusType.Invited && - ou.Status != OrganizationUserStatusType.Revoked && - ou.UserId.HasValue) - .ToList(); - - if (organizationUsers.Count == 0) - { - return string.Empty; - } - var hasNonCompliantUser = (await organizationUserRepository.GetManyByManyUsersAsync( organizationUsers.Select(ou => ou.UserId!.Value))) - .Any(uo => uo.OrganizationId != organizationId && - uo.Status != OrganizationUserStatusType.Invited); + .Any(uo => uo.OrganizationId != organizationId + && uo.Status != OrganizationUserStatusType.Invited); return hasNonCompliantUser ? _usersNotCompliantWithSingleOrgErrorMessage : string.Empty; } - private async Task ValidateNoProviderUsersAsync(Guid organizationId) + private async Task ValidateNoProviderUsersAsync(ICollection organizationUsers) { - var providerUsers = await providerUserRepository.GetManyByOrganizationAsync(organizationId); + var userIds = organizationUsers.Where(x => x.UserId is not null) + .Select(x => x.UserId!.Value); - return providerUsers.Count > 0 ? _providerUsersExistErrorMessage : string.Empty; + return (await providerUserRepository.GetManyByManyUsersAsync(userIds)).Count != 0 + ? _providerUsersExistErrorMessage + : string.Empty; } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs new file mode 100644 index 0000000000..92ba11f5a6 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs @@ -0,0 +1,59 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class BlockClaimedDomainAccountCreationPolicyValidator : IPolicyValidator, IPolicyValidationEvent +{ + private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; + private readonly IFeatureService _featureService; + + public BlockClaimedDomainAccountCreationPolicyValidator( + IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, + IFeatureService featureService) + { + _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _featureService = featureService; + } + + public PolicyType Type => PolicyType.BlockClaimedDomainAccountCreation; + + // No prerequisites - this policy stands alone + public IEnumerable RequiredPolicies => []; + + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + // Check if feature is enabled + if (!_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + return "This feature is not enabled"; + } + + // Only validate when trying to ENABLE the policy + if (policyUpdate is { Enabled: true }) + { + // Check if organization has at least one verified domain + if (!await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policyUpdate.OrganizationId)) + { + return "You must claim at least one domain to turn on this policy"; + } + } + + // Disabling the policy is always allowed + return string.Empty; + } + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + => Task.CompletedTask; +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs index c0378bf5f9..d24c61e258 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; @@ -29,8 +27,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent private readonly IOrganizationRepository _organizationRepository; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly ICurrentContext _currentContext; - private readonly IFeatureService _featureService; - private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand; @@ -40,8 +36,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent IOrganizationRepository organizationRepository, ISsoConfigRepository ssoConfigRepository, ICurrentContext currentContext, - IFeatureService featureService, - IRemoveOrganizationUserCommand removeOrganizationUserCommand, IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand) { @@ -50,8 +44,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent _organizationRepository = organizationRepository; _ssoConfigRepository = ssoConfigRepository; _currentContext = currentContext; - _featureService = featureService; - _removeOrganizationUserCommand = removeOrganizationUserCommand; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs index 0a774cf395..fb42ffa000 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs @@ -6,10 +6,23 @@ namespace Bit.Core.Repositories; public interface IOrganizationIntegrationConfigurationRepository : IRepository { - Task> GetConfigurationDetailsAsync( + /// + /// Retrieve the list of available configuration details for a specific event for the organization and + /// integration type.
    + ///
    + /// Note: This returns all configurations that match the event type explicitly and + /// all the configurations that have a null event type - null event type is considered a + /// wildcard that matches all events. + /// + ///
    + /// The specific event type + /// The id of the organization + /// The integration type + /// A List of that match + Task> GetManyByEventTypeOrganizationIdIntegrationType( + EventType eventType, Guid organizationId, - IntegrationType integrationType, - EventType eventType); + IntegrationType integrationType); Task> GetAllConfigurationDetailsAsync(); diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index b17de3c51d..41622c24b7 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.Entities; using Bit.Core.Enums; @@ -93,7 +94,18 @@ public interface IOrganizationUserRepository : IRepository - /// Accepted OrganizationUser to confirm + /// Accepted OrganizationUser to confirm /// True, if the user was updated. False, if not performed. - Task ConfirmOrganizationUserAsync(OrganizationUser organizationUser); + Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm); + + /// + /// Returns the OrganizationUserUserDetails if found. + /// + /// The id of the organization + /// The id of the User to fetch + /// OrganizationUserUserDetails of the specified user or null if not found + /// + /// Similar to GetByOrganizationAsync, but returns the user details. + /// + Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId); } diff --git a/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs b/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs index 7bc4125778..0a640b7530 100644 --- a/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs @@ -12,6 +12,7 @@ public interface IProviderUserRepository : IRepository Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); Task> GetManyAsync(IEnumerable ids); Task> GetManyByUserAsync(Guid userId); + Task> GetManyByManyUsersAsync(IEnumerable userIds); Task GetByProviderUserAsync(Guid providerId, Guid userId); Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); Task> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status = null); diff --git a/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs b/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..48346cbae7 --- /dev/null +++ b/src/Core/AdminConsole/Services/IOrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Services; + +public interface IOrganizationIntegrationConfigurationValidator +{ + /// + /// Validates that the configuration is valid for the given integration type. The configuration must + /// include a Configuration that is valid for the type, valid Filters, and a non-empty Template + /// to pass validation. + /// + /// The type of integration + /// The OrganizationIntegrationConfiguration to validate + /// True if valid, false otherwise + bool ValidateConfiguration(IntegrationType integrationType, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs b/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs index f81175f7b5..4f48b64b5a 100644 --- a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs +++ b/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs @@ -8,7 +8,7 @@ namespace Bit.Core.Services; public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService { public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Events.ConnectionString, "event"), + new QueueClient(globalSettings.Events.ConnectionString, globalSettings.Events.QueueName), JsonHelpers.IgnoreWritingNull) { } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs index 8423652eb8..b4246884f7 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs @@ -1,10 +1,16 @@ using System.Text.Json; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Core.Utilities; using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; namespace Bit.Core.Services; @@ -12,25 +18,17 @@ public class EventIntegrationHandler( IntegrationType integrationType, IEventIntegrationPublisher eventIntegrationPublisher, IIntegrationFilterService integrationFilterService, - IIntegrationConfigurationDetailsCache configurationCache, - IUserRepository userRepository, + IFusionCache cache, + IOrganizationIntegrationConfigurationRepository configurationRepository, + IGroupRepository groupRepository, IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, ILogger> logger) : IEventMessageHandler { public async Task HandleEventAsync(EventMessage eventMessage) { - if (eventMessage.OrganizationId is not Guid organizationId) - { - return; - } - - var configurations = configurationCache.GetConfigurationDetails( - organizationId, - integrationType, - eventMessage.Type); - - foreach (var configuration in configurations) + foreach (var configuration in await GetConfigurationDetailsListAsync(eventMessage)) { try { @@ -57,7 +55,7 @@ public class EventIntegrationHandler( { IntegrationType = integrationType, MessageId = messageId.ToString(), - OrganizationId = organizationId.ToString(), + OrganizationId = eventMessage.OrganizationId?.ToString(), Configuration = config, RenderedTemplate = renderedTemplate, RetryCount = 0, @@ -85,25 +83,83 @@ public class EventIntegrationHandler( } } - private async Task BuildContextAsync(EventMessage eventMessage, string template) + internal async Task BuildContextAsync(EventMessage eventMessage, string template) { + // Note: All of these cache calls use the default options, including TTL of 30 minutes + var context = new IntegrationTemplateContext(eventMessage); + if (IntegrationTemplateProcessor.TemplateRequiresGroup(template) && eventMessage.GroupId.HasValue) + { + context.Group = await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForGroup(eventMessage.GroupId.Value), + factory: async _ => await groupRepository.GetByIdAsync(eventMessage.GroupId.Value) + ); + } + + if (eventMessage.OrganizationId is not Guid organizationId) + { + return context; + } + if (IntegrationTemplateProcessor.TemplateRequiresUser(template) && eventMessage.UserId.HasValue) { - context.User = await userRepository.GetByIdAsync(eventMessage.UserId.Value); + context.User = await GetUserFromCacheAsync(organizationId, eventMessage.UserId.Value); } if (IntegrationTemplateProcessor.TemplateRequiresActingUser(template) && eventMessage.ActingUserId.HasValue) { - context.ActingUser = await userRepository.GetByIdAsync(eventMessage.ActingUserId.Value); + context.ActingUser = await GetUserFromCacheAsync(organizationId, eventMessage.ActingUserId.Value); } - if (IntegrationTemplateProcessor.TemplateRequiresOrganization(template) && eventMessage.OrganizationId.HasValue) + if (IntegrationTemplateProcessor.TemplateRequiresOrganization(template)) { - context.Organization = await organizationRepository.GetByIdAsync(eventMessage.OrganizationId.Value); + context.Organization = await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganization(organizationId), + factory: async _ => await organizationRepository.GetByIdAsync(organizationId) + ); } return context; } + + private async Task> GetConfigurationDetailsListAsync(EventMessage eventMessage) + { + if (eventMessage.OrganizationId is not Guid organizationId) + { + return []; + } + + List configurations = []; + + var integrationTag = EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integrationType + ); + + configurations.AddRange(await cache.GetOrSetAsync>( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integrationType, + eventType: eventMessage.Type), + factory: async _ => await configurationRepository.GetManyByEventTypeOrganizationIdIntegrationType( + eventType: eventMessage.Type, + organizationId: organizationId, + integrationType: integrationType), + options: new FusionCacheEntryOptions( + duration: EventIntegrationsCacheConstants.DurationForOrganizationIntegrationConfigurationDetails), + tags: [integrationTag] + )); + + return configurations; + } + + private async Task GetUserFromCacheAsync(Guid organizationId, Guid userId) => + await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationUser(organizationId, userId), + factory: async _ => await organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + organizationId: organizationId, + userId: userId + ) + ); } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs deleted file mode 100644 index a63efac62f..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs +++ /dev/null @@ -1,83 +0,0 @@ -using System.Diagnostics; -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.Services; - -public class IntegrationConfigurationDetailsCacheService : BackgroundService, IIntegrationConfigurationDetailsCache -{ - private readonly record struct IntegrationCacheKey(Guid OrganizationId, IntegrationType IntegrationType, EventType? EventType); - private readonly IOrganizationIntegrationConfigurationRepository _repository; - private readonly ILogger _logger; - private readonly TimeSpan _refreshInterval; - private Dictionary> _cache = new(); - - public IntegrationConfigurationDetailsCacheService( - IOrganizationIntegrationConfigurationRepository repository, - GlobalSettings globalSettings, - ILogger logger) - { - _repository = repository; - _logger = logger; - _refreshInterval = TimeSpan.FromMinutes(globalSettings.EventLogging.IntegrationCacheRefreshIntervalMinutes); - } - - public List GetConfigurationDetails( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) - { - var specificKey = new IntegrationCacheKey(organizationId, integrationType, eventType); - var allEventsKey = new IntegrationCacheKey(organizationId, integrationType, null); - - var results = new List(); - - if (_cache.TryGetValue(specificKey, out var specificConfigs)) - { - results.AddRange(specificConfigs); - } - if (_cache.TryGetValue(allEventsKey, out var fallbackConfigs)) - { - results.AddRange(fallbackConfigs); - } - - return results; - } - - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - await RefreshAsync(); - - var timer = new PeriodicTimer(_refreshInterval); - while (await timer.WaitForNextTickAsync(stoppingToken)) - { - await RefreshAsync(); - } - } - - internal async Task RefreshAsync() - { - var stopwatch = Stopwatch.StartNew(); - try - { - var newCache = (await _repository.GetAllConfigurationDetailsAsync()) - .GroupBy(x => new IntegrationCacheKey(x.OrganizationId, x.IntegrationType, x.EventType)) - .ToDictionary(g => g.Key, g => g.ToList()); - _cache = newCache; - - stopwatch.Stop(); - _logger.LogInformation( - "[IntegrationConfigurationDetailsCacheService] Refreshed successfully: {Count} entries in {Duration}ms", - newCache.Count, - stopwatch.Elapsed.TotalMilliseconds); - } - catch (Exception ex) - { - _logger.LogError("[IntegrationConfigurationDetailsCacheService] Refresh failed: {ex}", ex); - } - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md index 7570d47211..f9de5b9778 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md @@ -295,33 +295,60 @@ graph TD ``` ## Caching -To reduce database load and improve performance, integration configurations are cached in-memory as a Dictionary -with a periodic load of all configurations. Without caching, each incoming `EventMessage` would trigger a database +To reduce database load and improve performance, event integrations uses its own named extended cache (see +[CACHING in Utilities](https://github.com/bitwarden/server/blob/main/src/Core/Utilities/CACHING.md) +for more information). Without caching, for instance, each incoming `EventMessage` would trigger a database query to retrieve the relevant `OrganizationIntegrationConfigurationDetails`. -By loading all configurations into memory on a fixed interval, we ensure: +### `EventIntegrationsCacheConstants` -- Consistent performance for reads. -- Reduced database pressure. -- Predictable refresh timing, independent of event activity. +`EventIntegrationsCacheConstants` allows the code to have strongly typed references to a number of cache-related +details when working with the extended cache. The cache name and all cache keys and tags are programmatically accessed +from `EventIntegrationsCacheConstants` rather than simple strings. For instance, +`EventIntegrationsCacheConstants.CacheName` is used in the cache setup, keyed services, dependency injection, etc., +rather than using a string literal (i.e. "EventIntegrations") in code. -### Architecture / Design +### `OrganizationIntegrationConfigurationDetails` -- The cache is read-only for consumers. It is only updated in bulk by a background refresh process. -- The cache is fully replaced on each refresh to avoid locking or partial state. +- This is one of the most actively used portions of the architecture because any event that has an associated + organization requires a check of the configurations to determine if we need to fire off an integration. +- By using the extended cache, all reads are hitting the L1 or L2 cache before needing to access the database. - Reads return a `List` for a given key or an empty list if no match exists. -- Failures or delays in the loading process do not affect the existing cache state. The cache will continue serving - the last known good state until the update replaces the whole cache. +- The TTL is set very high on these records (1 day). This is because when the admin API makes any changes, it + tells the cache to remove that key. This propagates to the event listening code via the extended cache backplane, + which means that the cache is then expired and the next read will fetch the new values. This allows us to have + a high TTL and avoid needing to refresh values except when necessary. -### Background Refresh +#### Tagging per integration -A hosted service (`IntegrationConfigurationDetailsCacheService`) runs in the background and: +- Each entry in the cache (which again, returns `List`) is tagged with + the organization id and the integration type. +- This allows us to remove all of a given organization's configuration details for an integration when the admin + makes changes at the integration level. + - For instance, if there were 5 events configured for a given organization's webhook and the admin changed the URL + at the integration level, the updates would need to be propagated or else the cache will continue returning the + stale URL. + - By tagging each of the entries, the API can ask the extended cache to remove all the entries for a given + organization integration in one call. The cache will handle dropping / refreshing these entries in a + performant way. +- There are two places in the code that are both aware of the tagging functionality + - The `EventIntegrationHandler` must use the tag when fetching relevant configuration details. This tells the cache + to store the entry with the tag when it successfully loads from the repository. + - The `CreateOrganizationIntegrationCommand`, `UpdateOrganizationIntegrationCommand`, and + `DeleteOrganizationIntegrationCommand` commands need to use the tag to remove all the tagged entries when an admin + creates, updates, or deletes an integration. + - To ensure both places are synchronized on how to tag entries, they both use + `EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration` to build the tag. -- Loads all configuration records at application startup. -- Refreshes the cache on a configurable interval. -- Logs timing and entry count on success. -- Logs exceptions on failure without disrupting application flow. +### Template Properties + +- The `IntegrationTemplateProcessor` supports some properties that require an additional lookup. For instance, + the `UserId` is provided as part of the `EventMessage`, but `UserName` means an additional lookup to map the user + id to the actual name. +- The properties for a `User` (which includes `ActingUser`), `Group`, and `Organization` are cached via the + extended cache with a default TTL of 30 minutes. +- This is cached in both the L1 (Memory) and L2 (Redis) and will be automatically refreshed as needed. # Building a new integration diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 1b52ad8cff..e1fcbb970d 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -21,6 +21,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -47,7 +48,7 @@ public class OrganizationService : IOrganizationService private readonly IPushNotificationService _pushNotificationService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly ISsoUserRepository _ssoUserRepository; @@ -74,7 +75,7 @@ public class OrganizationService : IOrganizationService IPushNotificationService pushNotificationService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, ISsoUserRepository ssoUserRepository, @@ -148,7 +149,7 @@ public class OrganizationService : IOrganizationService } var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, - plan.PasswordManager.StripeStoragePlanId); + plan.PasswordManager.StripeStoragePlanId, plan.PasswordManager.BaseStorageGb); await ReplaceAndUpdateCacheAsync(organization); return secret; } @@ -358,7 +359,7 @@ public class OrganizationService : IOrganizationService { var newDisplayName = organization.DisplayName(); - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail, diff --git a/src/Core/AdminConsole/Services/OrganizationFactory.cs b/src/Core/AdminConsole/Services/OrganizationFactory.cs index f5df3327b1..0c64a27431 100644 --- a/src/Core/AdminConsole/Services/OrganizationFactory.cs +++ b/src/Core/AdminConsole/Services/OrganizationFactory.cs @@ -62,6 +62,7 @@ public static class OrganizationFactory UseAdminSponsoredFamilies = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAdminSponsoredFamilies), UseAutomaticUserConfirmation = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation), + UsePhishingBlocker = claimsPrincipal.GetValue(OrganizationLicenseConstants.UsePhishingBlocker), }; public static Organization Create( @@ -111,6 +112,7 @@ public static class OrganizationFactory UseRiskInsights = license.UseRiskInsights, UseOrganizationDomains = license.UseOrganizationDomains, UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies, - UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation, + UsePhishingBlocker = license.UsePhishingBlocker, }; } diff --git a/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs b/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..2769565675 --- /dev/null +++ b/src/Core/AdminConsole/Services/OrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,76 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Services; + +public class OrganizationIntegrationConfigurationValidator : IOrganizationIntegrationConfigurationValidator +{ + public bool ValidateConfiguration(IntegrationType integrationType, + OrganizationIntegrationConfiguration configuration) + { + // Validate template is present + if (string.IsNullOrWhiteSpace(configuration.Template)) + { + return false; + } + // If Filters are present, they must be valid + if (!IsFiltersValid(configuration.Filters)) + { + return false; + } + + switch (integrationType) + { + case IntegrationType.CloudBillingSync or IntegrationType.Scim: + return false; + case IntegrationType.Slack: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Webhook: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Hec: + case IntegrationType.Datadog: + case IntegrationType.Teams: + return configuration.Configuration is null; + default: + return false; + } + } + + private static bool IsConfigurationValid(string? configuration) + { + if (string.IsNullOrWhiteSpace(configuration)) + { + return false; + } + + try + { + var config = JsonSerializer.Deserialize(configuration); + return config is not null; + } + catch + { + return false; + } + } + + private static bool IsFiltersValid(string? filters) + { + if (filters is null) + { + return true; + } + + try + { + var filterGroup = JsonSerializer.Deserialize(filters); + return filterGroup is not null; + } + catch + { + return false; + } + } +} diff --git a/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs b/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs index b561e58a86..7fc8013c15 100644 --- a/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs +++ b/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Text.RegularExpressions; +using System.Text.RegularExpressions; namespace Bit.Core.AdminConsole.Utilities; @@ -26,7 +24,7 @@ public static partial class IntegrationTemplateProcessor return match.Value; // Return unknown keys as keys - i.e. #Key# } - return property?.GetValue(values)?.ToString() ?? ""; + return property.GetValue(values)?.ToString() ?? string.Empty; }); } @@ -38,7 +36,8 @@ public static partial class IntegrationTemplateProcessor } return template.Contains("#UserName#", StringComparison.Ordinal) - || template.Contains("#UserEmail#", StringComparison.Ordinal); + || template.Contains("#UserEmail#", StringComparison.Ordinal) + || template.Contains("#UserType#", StringComparison.Ordinal); } public static bool TemplateRequiresActingUser(string template) @@ -49,7 +48,18 @@ public static partial class IntegrationTemplateProcessor } return template.Contains("#ActingUserName#", StringComparison.Ordinal) - || template.Contains("#ActingUserEmail#", StringComparison.Ordinal); + || template.Contains("#ActingUserEmail#", StringComparison.Ordinal) + || template.Contains("#ActingUserType#", StringComparison.Ordinal); + } + + public static bool TemplateRequiresGroup(string template) + { + if (string.IsNullOrEmpty(template)) + { + return false; + } + + return template.Contains("#GroupName#", StringComparison.Ordinal); } public static bool TemplateRequiresOrganization(string template) diff --git a/src/Core/AdminConsole/Utilities/v2/Errors.cs b/src/Core/AdminConsole/Utilities/v2/Errors.cs new file mode 100644 index 0000000000..c1c66b2630 --- /dev/null +++ b/src/Core/AdminConsole/Utilities/v2/Errors.cs @@ -0,0 +1,15 @@ +namespace Bit.Core.AdminConsole.Utilities.v2; + +/// +/// A strongly typed error containing a reason that an action failed. +/// This is used for business logic validation and other expected errors, not exceptions. +/// +public abstract record Error(string Message); +/// +/// An type that maps to a NotFoundResult at the api layer. +/// +/// +public abstract record NotFoundError(string Message) : Error(Message); + +public abstract record BadRequestError(string Message) : Error(Message); +public abstract record InternalError(string Message) : Error(Message); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs b/src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs similarity index 94% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs rename to src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs index fbb00a908a..fb1bd16b2d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs +++ b/src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs @@ -1,7 +1,7 @@ using OneOf; using OneOf.Types; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +namespace Bit.Core.AdminConsole.Utilities.v2.Results; /// /// Represents the result of a command. @@ -39,4 +39,3 @@ public record BulkCommandResult(Guid Id, CommandResult Result); /// A wrapper for with an ID, to identify the result in bulk operations. /// public record BulkCommandResult(Guid Id, CommandResult Result); - diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs b/src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs similarity index 94% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs rename to src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs index c84a0aeda1..e28eac9a1c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs +++ b/src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs @@ -1,7 +1,7 @@ using OneOf; using OneOf.Types; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +namespace Bit.Core.AdminConsole.Utilities.v2.Validation; /// /// Represents the result of validating a request. diff --git a/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs b/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs new file mode 100644 index 0000000000..bcc4b851c0 --- /dev/null +++ b/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs @@ -0,0 +1,29 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.Auth.Models.Api.Request.Accounts; + +namespace Bit.Core.Auth.Attributes; + +public class MarketingInitiativeValidationAttribute : ValidationAttribute +{ + private static readonly string[] _acceptedValues = [MarketingInitiativeConstants.Premium]; + + public MarketingInitiativeValidationAttribute() + { + ErrorMessage = $"Marketing initiative type must be one of: {string.Join(", ", _acceptedValues)}"; + } + + public override bool IsValid(object? value) + { + if (value == null) + { + return true; + } + + if (value is not string str) + { + return false; + } + + return _acceptedValues.Contains(str); + } +} diff --git a/src/Core/Auth/Entities/AuthRequest.cs b/src/Core/Auth/Entities/AuthRequest.cs index 2117c575c0..38dc0534c1 100644 --- a/src/Core/Auth/Entities/AuthRequest.cs +++ b/src/Core/Auth/Entities/AuthRequest.cs @@ -49,11 +49,9 @@ public class AuthRequest : ITableObject public bool IsExpired() { - // TODO: PM-24252 - consider using TimeProvider for better mocking in tests return GetExpirationDate() < DateTime.UtcNow; } - // TODO: PM-24252 - this probably belongs in a service. public bool IsValidForAuthentication(Guid userId, string password) { diff --git a/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs b/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs deleted file mode 100644 index f8caad448b..0000000000 --- a/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs +++ /dev/null @@ -1,14 +0,0 @@ -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin; -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Auth.LoginFeatures; - -public static class LoginServiceCollectionExtensions -{ - public static void AddLoginServices(this IServiceCollection services) - { - services.AddScoped(); - } -} - diff --git a/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs b/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs deleted file mode 100644 index e5da1b06d8..0000000000 --- a/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; - -public interface IVerifyAuthRequestCommand -{ - Task VerifyAuthRequestAsync(Guid authRequestId, string accessCode); -} diff --git a/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs b/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs deleted file mode 100644 index 7def7fea76..0000000000 --- a/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs +++ /dev/null @@ -1,25 +0,0 @@ -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Utilities; - -namespace Bit.Core.Auth.LoginFeatures.PasswordlessLogin; - -public class VerifyAuthRequestCommand : IVerifyAuthRequestCommand -{ - private readonly IAuthRequestRepository _authRequestRepository; - - public VerifyAuthRequestCommand(IAuthRequestRepository authRequestRepository) - { - _authRequestRepository = authRequestRepository; - } - - public async Task VerifyAuthRequestAsync(Guid authRequestId, string accessCode) - { - var authRequest = await _authRequestRepository.GetByIdAsync(authRequestId); - if (authRequest == null || !CoreHelpers.FixedTimeEquals(authRequest.AccessCode, accessCode)) - { - return false; - } - return true; - } -} diff --git a/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs index f89b67f3c5..85ddef44ce 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs @@ -3,17 +3,22 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; public class KeysRequestModel { + [Obsolete("Use AccountKeys.AccountPublicKey instead")] [Required] public string PublicKey { get; set; } + [Obsolete("Use AccountKeys.UserKeyEncryptedAccountPrivateKey instead")] [Required] public string EncryptedPrivateKey { get; set; } + public AccountKeysRequestModel AccountKeys { get; set; } + [Obsolete("Use SetAccountKeysForUserCommand instead")] public User ToUser(User existingUser) { if (string.IsNullOrWhiteSpace(PublicKey) || string.IsNullOrWhiteSpace(EncryptedPrivateKey)) diff --git a/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs b/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs new file mode 100644 index 0000000000..ab2d252dc8 --- /dev/null +++ b/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs @@ -0,0 +1,10 @@ +namespace Bit.Core.Auth.Models.Api.Request.Accounts; + +public static class MarketingInitiativeConstants +{ + /// + /// Indicates that the user began the registration process on a marketing page designed + /// to streamline users who intend to setup a premium subscription after registration. + /// + public const string Premium = "premium"; +} diff --git a/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs index 75a4da081a..638565ecfe 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs @@ -1,5 +1,6 @@ #nullable enable using System.ComponentModel.DataAnnotations; +using Bit.Core.Auth.Attributes; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; @@ -11,4 +12,6 @@ public class RegisterSendVerificationEmailRequestModel [StringLength(256)] public required string Email { get; set; } public bool ReceiveMarketingEmails { get; set; } + [MarketingInitiativeValidation] + public string? FromMarketing { get; set; } } diff --git a/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs b/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs index f04a1181c4..5be7ed481f 100644 --- a/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs +++ b/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.Entities; using Bit.Core.Tokens; @@ -26,7 +23,7 @@ public class OrgUserInviteTokenable : ExpiringTokenable public string Identifier { get; set; } = TokenIdentifier; public Guid OrgUserId { get; set; } - public string OrgUserEmail { get; set; } + public string? OrgUserEmail { get; set; } [JsonConstructor] public OrgUserInviteTokenable() diff --git a/src/Core/Auth/Models/ITwoFactorProvidersUser.cs b/src/Core/Auth/Models/ITwoFactorProvidersUser.cs index 5cf137b76f..816d460572 100644 --- a/src/Core/Auth/Models/ITwoFactorProvidersUser.cs +++ b/src/Core/Auth/Models/ITwoFactorProvidersUser.cs @@ -1,14 +1,14 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Enums; using Bit.Core.Services; namespace Bit.Core.Auth.Models; +/// +/// An interface representing a user entity that supports two-factor providers +/// public interface ITwoFactorProvidersUser { - string TwoFactorProviders { get; } + string? TwoFactorProviders { get; } /// /// Get the two factor providers for the user. Currently it can be assumed providers are enabled /// if they exists in the dictionary. When two factor providers are disabled they are removed @@ -16,7 +16,10 @@ public interface ITwoFactorProvidersUser /// /// /// Dictionary of providers with the type enum as the key - Dictionary GetTwoFactorProviders(); + Dictionary? GetTwoFactorProviders(); + /// + /// The unique `UserId` of the user entity for which there are two-factor providers configured. + /// + /// The unique identifier for the user Guid? GetUserId(); - bool GetPremium(); } diff --git a/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs b/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs index fe42093111..5c0efeb73f 100644 --- a/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs +++ b/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs @@ -15,11 +15,13 @@ public class RegisterVerifyEmail : BaseMailModel // so we must land on a redirect connector which will redirect to the finish signup page. // Note 3: The use of a fragment to indicate the redirect url is to prevent the query string from being logged by // proxies and servers. It also helps reduce open redirect vulnerabilities. - public string Url => string.Format("{0}/redirect-connector.html#finish-signup?token={1}&email={2}&fromEmail=true", + public string Url => string.Format("{0}/redirect-connector.html#finish-signup?token={1}&email={2}&fromEmail=true{3}", WebVaultUrl, Token, - Email); + Email, + !string.IsNullOrEmpty(FromMarketing) ? $"&fromMarketing={FromMarketing}" : string.Empty); public string Token { get; set; } public string Email { get; set; } + public string FromMarketing { get; set; } } diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index 1a35585b2c..0cb8b68042 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -5,7 +5,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; @@ -26,8 +25,6 @@ public class SsoConfigService : ISsoConfigService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; - private readonly IFeatureService _featureService; - private readonly ISavePolicyCommand _savePolicyCommand; private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public SsoConfigService( @@ -36,8 +33,6 @@ public class SsoConfigService : ISsoConfigService IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IEventService eventService, - IFeatureService featureService, - ISavePolicyCommand savePolicyCommand, IVNextSavePolicyCommand vNextSavePolicyCommand) { _ssoConfigRepository = ssoConfigRepository; @@ -45,8 +40,6 @@ public class SsoConfigService : ISsoConfigService _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _eventService = eventService; - _featureService = featureService; - _savePolicyCommand = savePolicyCommand; _vNextSavePolicyCommand = vNextSavePolicyCommand; } @@ -97,19 +90,10 @@ public class SsoConfigService : ISsoConfigService Enabled = true }; - if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) - { - var performedBy = new SystemUser(EventSystemUser.Unknown); - await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(singleOrgPolicy, performedBy)); - await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(resetPasswordPolicy, performedBy)); - await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(requireSsoPolicy, performedBy)); - } - else - { - await _savePolicyCommand.SaveAsync(singleOrgPolicy); - await _savePolicyCommand.SaveAsync(resetPasswordPolicy); - await _savePolicyCommand.SaveAsync(requireSsoPolicy); - } + var performedBy = new SystemUser(EventSystemUser.Unknown); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(singleOrgPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(resetPasswordPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(requireSsoPolicy, performedBy)); } await LogEventsAsync(config, oldConfig); diff --git a/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c932eb0c34 --- /dev/null +++ b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,23 @@ +using Bit.Core.Entities; + +namespace Bit.Core.Auth.Sso; + +/// +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public interface IUserSsoOrganizationIdentifierQuery +{ + /// + /// Retrieves the SSO organization identifier for a confirmed organization user. + /// If there is more than one organization a User is associated with, we return null. If there are more than one + /// organization there is no way to know which organization the user wishes to authenticate with. + /// Owners and Admins who are not subject to the SSO required policy cannot utilize this flow, since they may have + /// multiple organizations with different SSO configurations. + /// + /// The ID of the to retrieve the SSO organization for. _Not_ an . + /// + /// The organization identifier if the user is a confirmed member of an organization with SSO configured, + /// otherwise null + /// + Task GetSsoOrganizationIdentifierAsync(Guid userId); +} diff --git a/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c0751e1f1a --- /dev/null +++ b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,38 @@ +using Bit.Core.Enums; +using Bit.Core.Repositories; + +namespace Bit.Core.Auth.Sso; + +/// +/// TODO : PM-28846 review data structures as they relate to this query +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public class UserSsoOrganizationIdentifierQuery( + IOrganizationUserRepository _organizationUserRepository, + IOrganizationRepository _organizationRepository) : IUserSsoOrganizationIdentifierQuery +{ + /// + public async Task GetSsoOrganizationIdentifierAsync(Guid userId) + { + // Get all confirmed organization memberships for the user + var organizationUsers = await _organizationUserRepository.GetManyByUserAsync(userId); + + // we can only confidently return the correct SsoOrganizationIdentifier if there is exactly one Organization. + // The user must also be in the Confirmed status. + var confirmedOrgUsers = organizationUsers.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed); + if (confirmedOrgUsers.Count() != 1) + { + return null; + } + + var confirmedOrgUser = confirmedOrgUsers.Single(); + var organization = await _organizationRepository.GetByIdAsync(confirmedOrgUser.OrganizationId); + + if (organization == null) + { + return null; + } + + return organization.Identifier; + } +} diff --git a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs index 62dd9dd293..97c2eabd3c 100644 --- a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs @@ -1,4 +1,5 @@ -using Bit.Core.Entities; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Entities; using Microsoft.AspNetCore.Identity; namespace Bit.Core.Auth.UserFeatures.Registration; @@ -14,6 +15,15 @@ public interface IRegisterUserCommand /// public Task RegisterUser(User user); + /// + /// Creates a new user, sends a welcome email, and raises the signup reference event. + /// This method is used by SSO auto-provisioned organization Users. + /// + /// The to create + /// The associated with the user + /// + Task RegisterSSOAutoProvisionedUserAsync(User user, Organization organization); + /// /// Creates a new user with a given master password hash, sends a welcome email (differs based on initiation path), /// and raises the signup reference event. Optionally accepts an org invite token and org user id to associate diff --git a/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs b/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs index b623b8cab3..2a224b9eb9 100644 --- a/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs @@ -3,5 +3,5 @@ namespace Bit.Core.Auth.UserFeatures.Registration; public interface ISendVerificationEmailForRegistrationCommand { - public Task Run(string email, string? name, bool receiveMarketingEmails); + public Task Run(string email, string? name, bool receiveMarketingEmails, string? fromMarketing); } diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs index 991be2b764..4a0e9c2cf5 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs @@ -1,11 +1,11 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -16,15 +16,20 @@ using Bit.Core.Tokens; using Bit.Core.Utilities; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Identity; +using Microsoft.Extensions.Logging; using Newtonsoft.Json; namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; public class RegisterUserCommand : IRegisterUserCommand { + private readonly ILogger _logger; private readonly IGlobalSettings _globalSettings; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationDomainRepository _organizationDomainRepository; + private readonly IFeatureService _featureService; private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IDataProtectorTokenFactory _registrationEmailVerificationTokenDataFactory; @@ -41,21 +46,28 @@ public class RegisterUserCommand : IRegisterUserCommand private readonly string _disabledUserRegistrationExceptionMsg = "Open registration has been disabled by the system administrator."; public RegisterUserCommand( - IGlobalSettings globalSettings, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - IDataProtectionProvider dataProtectionProvider, - IDataProtectorTokenFactory orgUserInviteTokenDataFactory, - IDataProtectorTokenFactory registrationEmailVerificationTokenDataFactory, - IUserService userService, - IMailService mailService, - IValidateRedemptionTokenCommand validateRedemptionTokenCommand, - IDataProtectorTokenFactory emergencyAccessInviteTokenDataFactory - ) + ILogger logger, + IGlobalSettings globalSettings, + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IPolicyRepository policyRepository, + IOrganizationDomainRepository organizationDomainRepository, + IFeatureService featureService, + IDataProtectionProvider dataProtectionProvider, + IDataProtectorTokenFactory orgUserInviteTokenDataFactory, + IDataProtectorTokenFactory registrationEmailVerificationTokenDataFactory, + IUserService userService, + IMailService mailService, + IValidateRedemptionTokenCommand validateRedemptionTokenCommand, + IDataProtectorTokenFactory emergencyAccessInviteTokenDataFactory) { + _logger = logger; _globalSettings = globalSettings; _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; _policyRepository = policyRepository; + _organizationDomainRepository = organizationDomainRepository; + _featureService = featureService; _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( "OrganizationServiceDataProtector"); @@ -69,11 +81,13 @@ public class RegisterUserCommand : IRegisterUserCommand _emergencyAccessInviteTokenDataFactory = emergencyAccessInviteTokenDataFactory; _providerServiceDataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + _featureService = featureService; } - public async Task RegisterUser(User user) { + await ValidateEmailDomainNotBlockedAsync(user.Email); + var result = await _userService.CreateUserAsync(user); if (result == IdentityResult.Success) { @@ -83,11 +97,30 @@ public class RegisterUserCommand : IRegisterUserCommand return result; } + public async Task RegisterSSOAutoProvisionedUserAsync(User user, Organization organization) + { + // Validate that the email domain is not blocked by another organization's policy + await ValidateEmailDomainNotBlockedAsync(user.Email, organization.Id); + + var result = await _userService.CreateUserAsync(user); + if (result == IdentityResult.Success) + { + await SendWelcomeEmailAsync(user, organization); + } + + return result; + } + public async Task RegisterUserViaOrganizationInviteToken(User user, string masterPasswordHash, string orgInviteToken, Guid? orgUserId) { - ValidateOrgInviteToken(orgInviteToken, orgUserId, user); - await SetUserEmail2FaIfOrgPolicyEnabledAsync(orgUserId, user); + TryValidateOrgInviteToken(orgInviteToken, orgUserId, user); + var orgUser = await SetUserEmail2FaIfOrgPolicyEnabledAsync(orgUserId, user); + if (orgUser == null && orgUserId.HasValue) + { + throw new BadRequestException("Invalid organization user invitation."); + } + await ValidateEmailDomainNotBlockedAsync(user.Email, orgUser?.OrganizationId); user.ApiKey = CoreHelpers.SecureRandomString(30); @@ -97,16 +130,17 @@ public class RegisterUserCommand : IRegisterUserCommand } var result = await _userService.CreateUserAsync(user, masterPasswordHash); + var organization = await GetOrganizationUserOrganization(orgUserId ?? Guid.Empty, orgUser); if (result == IdentityResult.Success) { var sentWelcomeEmail = false; if (!string.IsNullOrEmpty(user.ReferenceData)) { - var referenceData = JsonConvert.DeserializeObject>(user.ReferenceData); + var referenceData = JsonConvert.DeserializeObject>(user.ReferenceData) ?? []; if (referenceData.TryGetValue("initiationPath", out var value)) { - var initiationPath = value.ToString(); - await SendAppropriateWelcomeEmailAsync(user, initiationPath); + var initiationPath = value.ToString() ?? string.Empty; + await SendAppropriateWelcomeEmailAsync(user, initiationPath, organization); sentWelcomeEmail = true; if (!string.IsNullOrEmpty(initiationPath)) { @@ -117,14 +151,22 @@ public class RegisterUserCommand : IRegisterUserCommand if (!sentWelcomeEmail) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user, organization); } } return result; } - private void ValidateOrgInviteToken(string orgInviteToken, Guid? orgUserId, User user) + /// + /// This method attempts to validate the org invite token if provided. If the token is invalid an exception is thrown. + /// If there is no exception it is assumed the token is valid or not provided and open registration is allowed. + /// + /// The organization invite token. + /// The organization user ID. + /// The user being registered. + /// If validation fails then an exception is thrown. + private void TryValidateOrgInviteToken(string orgInviteToken, Guid? orgUserId, User user) { var orgInviteTokenProvided = !string.IsNullOrWhiteSpace(orgInviteToken); @@ -137,7 +179,6 @@ public class RegisterUserCommand : IRegisterUserCommand } // Token data is invalid - if (_globalSettings.DisableUserRegistration) { throw new BadRequestException(_disabledUserRegistrationExceptionMsg); @@ -147,7 +188,6 @@ public class RegisterUserCommand : IRegisterUserCommand } // no token data or missing token data - // Throw if open registration is disabled and there isn't an org invite token or an org user id // as you can't register without them. if (_globalSettings.DisableUserRegistration) @@ -171,12 +211,20 @@ public class RegisterUserCommand : IRegisterUserCommand // If both orgInviteToken && orgUserId are missing, then proceed with open registration } + /// + /// Validates the org invite token using the new tokenable logic first, then falls back to the old token validation logic for backwards compatibility. + /// Will set the out parameter organizationWelcomeEmailDetails if the new token is valid. If the token is invalid then no welcome email needs to be sent + /// so the out parameter is set to null. + /// + /// Invite token + /// Inviting Organization UserId + /// User email + /// true if the token is valid false otherwise private bool IsOrgInviteTokenValid(string orgInviteToken, Guid orgUserId, string userEmail) { // TODO: PM-4142 - remove old token validation logic once 3 releases of backwards compatibility are complete var newOrgInviteTokenValid = OrgUserInviteTokenable.ValidateOrgUserInviteStringToken( _orgUserInviteTokenDataFactory, orgInviteToken, orgUserId, userEmail); - return newOrgInviteTokenValid || CoreHelpers.UserInviteTokenIsValid( _organizationServiceDataProtector, orgInviteToken, userEmail, orgUserId, _globalSettings); } @@ -187,11 +235,12 @@ public class RegisterUserCommand : IRegisterUserCommand /// /// The optional org user id /// The newly created user object which could be modified - private async Task SetUserEmail2FaIfOrgPolicyEnabledAsync(Guid? orgUserId, User user) + /// The organization user if one exists for the provided org user id, null otherwise + private async Task SetUserEmail2FaIfOrgPolicyEnabledAsync(Guid? orgUserId, User user) { if (!orgUserId.HasValue) { - return; + return null; } var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); @@ -213,10 +262,11 @@ public class RegisterUserCommand : IRegisterUserCommand _userService.SetTwoFactorProvider(user, TwoFactorProviderType.Email); } } + return orgUser; } - private async Task SendAppropriateWelcomeEmailAsync(User user, string initiationPath) + private async Task SendAppropriateWelcomeEmailAsync(User user, string initiationPath, Organization? organization) { var isFromMarketingWebsite = initiationPath.Contains("Secrets Manager trial"); @@ -226,15 +276,15 @@ public class RegisterUserCommand : IRegisterUserCommand } else { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user, organization); } } public async Task RegisterUserViaEmailVerificationToken(User user, string masterPasswordHash, string emailVerificationToken) { - ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); var tokenable = ValidateRegistrationEmailVerificationTokenable(emailVerificationToken, user.Email); @@ -245,7 +295,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -255,6 +305,7 @@ public class RegisterUserCommand : IRegisterUserCommand string orgSponsoredFreeFamilyPlanInviteToken) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); await ValidateOrgSponsoredFreeFamilyPlanInviteToken(orgSponsoredFreeFamilyPlanInviteToken, user.Email); user.EmailVerified = true; @@ -263,7 +314,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -275,6 +326,7 @@ public class RegisterUserCommand : IRegisterUserCommand string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); ValidateAcceptEmergencyAccessInviteToken(acceptEmergencyAccessInviteToken, acceptEmergencyAccessId, user.Email); user.EmailVerified = true; @@ -283,7 +335,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -293,6 +345,7 @@ public class RegisterUserCommand : IRegisterUserCommand string providerInviteToken, Guid providerUserId) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); ValidateProviderInviteToken(providerInviteToken, providerUserId, user.Email); user.EmailVerified = true; @@ -301,7 +354,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -357,4 +410,79 @@ public class RegisterUserCommand : IRegisterUserCommand return tokenable; } + + private async Task ValidateEmailDomainNotBlockedAsync(string email, Guid? excludeOrganizationId = null) + { + // Only check if feature flag is enabled + if (!_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + return; + } + + var emailDomain = EmailValidation.GetDomain(email); + + var isDomainBlocked = await _organizationDomainRepository.HasVerifiedDomainWithBlockClaimedDomainPolicyAsync( + emailDomain, excludeOrganizationId); + if (isDomainBlocked) + { + _logger.LogInformation( + "User registration blocked by domain claim policy. Domain: {Domain}, ExcludedOrgId: {ExcludedOrgId}", + emailDomain, + excludeOrganizationId); + throw new BadRequestException("This email address is claimed by an organization using Bitwarden."); + } + } + + /// + /// We send different welcome emails depending on whether the user is joining a free/family or an enterprise organization. If information to populate the + /// email isn't present we send the standard individual welcome email. + /// + /// Target user for the email + /// this value is nullable + /// + private async Task SendWelcomeEmailAsync(User user, Organization? organization = null) + { + // Check if feature is enabled + // TODO: Remove Feature flag: PM-28221 + if (!_featureService.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)) + { + await _mailService.SendWelcomeEmailAsync(user); + return; + } + + // Most emails are probably for non organization users so we default to that experience + if (organization == null) + { + await _mailService.SendIndividualUserWelcomeEmailAsync(user); + } + // We need to make sure that the organization email has the correct data to display otherwise we just send the standard welcome email + else if (!string.IsNullOrEmpty(organization.DisplayName())) + { + // If the organization is Free or Families plan, send families welcome email + if (organization.PlanType.GetProductTier() is ProductTierType.Free or ProductTierType.Families) + { + await _mailService.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, organization.DisplayName()); + } + else + { + await _mailService.SendOrganizationUserWelcomeEmailAsync(user, organization.DisplayName()); + } + } + // If the organization data isn't present send the standard welcome email + else + { + await _mailService.SendIndividualUserWelcomeEmailAsync(user); + } + } + + private async Task GetOrganizationUserOrganization(Guid orgUserId, OrganizationUser? orgUser = null) + { + var organizationUser = orgUser ?? await _organizationUserRepository.GetByIdAsync(orgUserId); + if (organizationUser == null) + { + return null; + } + + return await _organizationRepository.GetByIdAsync(organizationUser.OrganizationId); + } } diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs index 3f89e9ad0e..2e8587eee6 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs @@ -5,6 +5,8 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Tokens; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; @@ -15,29 +17,34 @@ namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; /// public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmailForRegistrationCommand { - + private readonly ILogger _logger; private readonly IUserRepository _userRepository; private readonly GlobalSettings _globalSettings; private readonly IMailService _mailService; private readonly IDataProtectorTokenFactory _tokenDataFactory; private readonly IFeatureService _featureService; + private readonly IOrganizationDomainRepository _organizationDomainRepository; public SendVerificationEmailForRegistrationCommand( + ILogger logger, IUserRepository userRepository, GlobalSettings globalSettings, IMailService mailService, IDataProtectorTokenFactory tokenDataFactory, - IFeatureService featureService) + IFeatureService featureService, + IOrganizationDomainRepository organizationDomainRepository) { + _logger = logger; _userRepository = userRepository; _globalSettings = globalSettings; _mailService = mailService; _tokenDataFactory = tokenDataFactory; _featureService = featureService; + _organizationDomainRepository = organizationDomainRepository; } - public async Task Run(string email, string? name, bool receiveMarketingEmails) + public async Task Run(string email, string? name, bool receiveMarketingEmails, string? fromMarketing) { if (_globalSettings.DisableUserRegistration) { @@ -49,6 +56,20 @@ public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmai throw new ArgumentNullException(nameof(email)); } + // Check if the email domain is blocked by an organization policy + if (_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + var emailDomain = EmailValidation.GetDomain(email); + + if (await _organizationDomainRepository.HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(emailDomain)) + { + _logger.LogInformation( + "User registration email verification blocked by domain claim policy. Domain: {Domain}", + emailDomain); + throw new BadRequestException("This email address is claimed by an organization using Bitwarden."); + } + } + // Check to see if the user already exists var user = await _userRepository.GetByEmailAsync(email); var userExists = user != null; @@ -71,7 +92,7 @@ public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmai // If the user doesn't exist, create a new EmailVerificationTokenable and send the user // an email with a link to verify their email address var token = GenerateToken(email, name, receiveMarketingEmails); - await _mailService.SendRegistrationVerificationEmailAsync(email, token); + await _mailService.SendRegistrationVerificationEmailAsync(email, token, fromMarketing); } // User exists but we will return a 200 regardless of whether the email was sent or not; so return null diff --git a/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs b/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs index cc86d3d71d..e6c0c1444a 100644 --- a/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs +++ b/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs @@ -4,16 +4,37 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Premium.Queries; +using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Repositories; +using Bit.Core.Services; namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth; -public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery +public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery { - private readonly IUserRepository _userRepository = userRepository; + private readonly IUserRepository _userRepository; + private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery; + private readonly IFeatureService _featureService; + + public TwoFactorIsEnabledQuery( + IUserRepository userRepository, + IHasPremiumAccessQuery hasPremiumAccessQuery, + IFeatureService featureService) + { + _userRepository = userRepository; + _hasPremiumAccessQuery = hasPremiumAccessQuery; + _featureService = featureService; + } public async Task> TwoFactorIsEnabledAsync(IEnumerable userIds) { + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await TwoFactorIsEnabledVNextAsync(userIds); + } + var result = new List<(Guid userId, bool hasTwoFactor)>(); if (userIds == null || !userIds.Any()) { @@ -36,6 +57,11 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto public async Task> TwoFactorIsEnabledAsync(IEnumerable users) where T : ITwoFactorProvidersUser { + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await TwoFactorIsEnabledVNextAsync(users); + } + var userIds = users .Select(u => u.GetUserId()) .Where(u => u.HasValue) @@ -71,13 +97,134 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto return false; } + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + var userEntity = user as User ?? await _userRepository.GetByIdAsync(userId.Value); + if (userEntity == null) + { + throw new NotFoundException(); + } + + return await TwoFactorIsEnabledVNextAsync(userEntity); + } + return await TwoFactorEnabledAsync( - user.GetTwoFactorProviders(), - async () => - { - var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value); - return calcUser?.HasPremiumAccess ?? false; - }); + user.GetTwoFactorProviders(), + async () => + { + var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value); + return calcUser?.HasPremiumAccess ?? false; + }); + } + + private async Task> TwoFactorIsEnabledVNextAsync(IEnumerable userIds) + { + var result = new List<(Guid userId, bool hasTwoFactor)>(); + if (userIds == null || !userIds.Any()) + { + return result; + } + + var users = await _userRepository.GetManyAsync([.. userIds]); + + // Get enabled providers for each user + var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders); + + // Bulk fetch premium status only for users who need it (those with only premium providers) + var userIdsNeedingPremium = usersTwoFactorProvidersMap + .Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium)) + .Select(kvp => kvp.Key) + .ToList(); + + var premiumStatusMap = userIdsNeedingPremium.Count > 0 + ? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium) + : new Dictionary(); + + foreach (var user in users) + { + var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id]; + + if (!userTwoFactorProviders.Any()) + { + result.Add((user.Id, false)); + continue; + } + + // User has providers. If they're in the premium check map, verify premium status + var twoFactorIsEnabled = !premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium; + result.Add((user.Id, twoFactorIsEnabled)); + } + + return result; + } + + private async Task> TwoFactorIsEnabledVNextAsync(IEnumerable users) + where T : ITwoFactorProvidersUser + { + var userIds = users + .Select(u => u.GetUserId()) + .Where(u => u.HasValue) + .Select(u => u.Value) + .ToList(); + + var twoFactorResults = await TwoFactorIsEnabledVNextAsync(userIds); + + var result = new List<(T user, bool twoFactorIsEnabled)>(); + + foreach (var user in users) + { + var userId = user.GetUserId(); + if (userId.HasValue) + { + var hasTwoFactor = twoFactorResults.FirstOrDefault(res => res.userId == userId.Value).twoFactorIsEnabled; + result.Add((user, hasTwoFactor)); + } + else + { + result.Add((user, false)); + } + } + + return result; + } + + private async Task TwoFactorIsEnabledVNextAsync(User user) + { + var enabledProviders = GetEnabledTwoFactorProviders(user); + + if (!enabledProviders.Any()) + { + return false; + } + + // If all providers require premium, check if user has premium access + if (enabledProviders.All(TwoFactorProvider.RequiresPremium)) + { + return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id); + } + + // User has at least one non-premium provider + return true; + } + + /// + /// Gets all enabled two-factor provider types for a user. + /// + /// user with two factor providers + /// list of enabled provider types + private static IList GetEnabledTwoFactorProviders(User user) + { + var providers = user.GetTwoFactorProviders(); + + if (providers == null || providers.Count == 0) + { + return Array.Empty(); + } + + // TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into. + return (from provider in providers + where provider.Value?.Enabled ?? false + select provider.Key).ToList(); } /// diff --git a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs index 53bd8bdba2..7c50f7f17b 100644 --- a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs +++ b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ - - +using Bit.Core.Auth.Sso; using Bit.Core.Auth.UserFeatures.DeviceTrust; using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Auth.UserFeatures.Registration.Implementations; @@ -29,6 +28,7 @@ public static class UserServiceCollectionExtensions services.AddWebAuthnLoginCommands(); services.AddTdeOffboardingPasswordCommands(); services.AddTwoFactorQueries(); + services.AddSsoQueries(); } public static void AddDeviceTrustCommands(this IServiceCollection services) @@ -69,4 +69,9 @@ public static class UserServiceCollectionExtensions { services.AddScoped(); } + + private static void AddSsoQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 11f043fc69..dc128127ae 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -12,6 +12,12 @@ public static class StripeConstants public const string UnrecognizedLocation = "unrecognized_location"; } + public static class BillingReasons + { + public const string SubscriptionCreate = "subscription_create"; + public const string SubscriptionCycle = "subscription_cycle"; + } + public static class CollectionMethod { public const string ChargeAutomatically = "charge_automatically"; @@ -65,6 +71,7 @@ public static class StripeConstants public const string Region = "region"; public const string RetiredBraintreeCustomerId = "btCustomerId_old"; public const string UserId = "userId"; + public const string StorageReconciled2025 = "storage_reconciled_2025"; } public static class PaymentBehavior diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index d6593f5365..5ceefed603 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -6,6 +6,7 @@ using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Payment; using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Premium.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -31,6 +32,7 @@ public static class ServiceCollectionExtensions services.AddPaymentOperations(); services.AddOrganizationLicenseCommandsQueries(); services.AddPremiumCommands(); + services.AddPremiumQueries(); services.AddTransient(); services.AddTransient(); services.AddTransient(); @@ -50,4 +52,9 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddTransient(); } + + private static void AddPremiumQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs index 79ac94be62..727bcbc229 100644 --- a/src/Core/Billing/Licenses/LicenseConstants.cs +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -44,6 +44,7 @@ public static class OrganizationLicenseConstants public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies); public const string UseOrganizationDomains = nameof(UseOrganizationDomains); public const string UseAutomaticUserConfirmation = nameof(UseAutomaticUserConfirmation); + public const string UsePhishingBlocker = nameof(UsePhishingBlocker); } public static class UserLicenseConstants diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs index e9aadbe758..4a4771857e 100644 --- a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -26,7 +26,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory All { get; set; } = + [ + new() + { + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + SponsoredProductTierType = ProductTierType.Families, + SponsoringProductTierType = ProductTierType.Enterprise, + StripePlanId = "2021-family-for-enterprise-annually", + UsersCanSponsor = org => + org.PlanType.GetProductTier() == ProductTierType.Enterprise, + } + ]; + + public static SponsoredPlan Get(PlanSponsorshipType planSponsorshipType) => + All.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType)!; +} diff --git a/src/Core/Billing/Models/StaticStore/Plan.cs b/src/Core/Billing/Models/StaticStore/Plan.cs index 6d8d00089c..bab64d9879 100644 --- a/src/Core/Billing/Models/StaticStore/Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plan.cs @@ -97,7 +97,7 @@ public abstract record Plan public decimal PremiumAccessOptionPrice { get; init; } public short? MaxSeats { get; init; } // Storage - public short? BaseStorageGb { get; init; } + public short BaseStorageGb { get; init; } public bool HasAdditionalStorageOption { get; init; } public decimal AdditionalStoragePricePerGb { get; init; } public string StripeStoragePlanId { get; init; } diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs index 89d301c22a..2a5e786c98 100644 --- a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -3,12 +3,12 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; -using Bit.Core.Services; -using Bit.Core.Utilities; using Microsoft.Extensions.Logging; using OneOf; using Stripe; @@ -54,7 +54,7 @@ public class PreviewOrganizationTaxCommand( switch (purchase) { case { PasswordManager.Sponsored: true }: - var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise); + var sponsoredPlan = SponsoredPlans.Get(PlanSponsorshipType.FamiliesForEnterprise); items.Add(new InvoiceSubscriptionDetailsItemOptions { Price = sponsoredPlan.StripePlanId, @@ -125,7 +125,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); @@ -165,7 +165,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } else @@ -181,7 +181,7 @@ public class PreviewOrganizationTaxCommand( var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer"] }); if (subscription.Customer.Discount != null) @@ -259,7 +259,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } }); @@ -278,7 +278,7 @@ public class PreviewOrganizationTaxCommand( return new BadRequest("Organization does not have a subscription."); } - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.tax_ids"] }); var options = GetBaseOptions(subscription.Customer, @@ -336,7 +336,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); diff --git a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs index 7ccbacc938..584021f22f 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs @@ -143,6 +143,7 @@ public class OrganizationLicense : ILicense public int? SmSeats { get; set; } public int? SmServiceAccounts { get; set; } public bool UseRiskInsights { get; set; } + public bool UsePhishingBlocker { get; set; } // Deprecated. Left for backwards compatibility with old license versions. public bool LimitCollectionCreationDeletion { get; set; } = true; @@ -228,7 +229,8 @@ public class OrganizationLicense : ILicense !p.Name.Equals(nameof(UseRiskInsights)) && !p.Name.Equals(nameof(UseAdminSponsoredFamilies)) && !p.Name.Equals(nameof(UseOrganizationDomains)) && - !p.Name.Equals(nameof(UseAutomaticUserConfirmation))) + !p.Name.Equals(nameof(UseAutomaticUserConfirmation)) && + !p.Name.Equals(nameof(UsePhishingBlocker))) .OrderBy(p => p.Name) .Select(p => $"{p.Name}:{Core.Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Aggregate((c, n) => $"{c}|{n}"); @@ -399,7 +401,6 @@ public class OrganizationLicense : ILicense var installationId = claimsPrincipal.GetValue(nameof(InstallationId)); var licenseKey = claimsPrincipal.GetValue(nameof(LicenseKey)); var enabled = claimsPrincipal.GetValue(nameof(Enabled)); - var planType = claimsPrincipal.GetValue(nameof(PlanType)); var seats = claimsPrincipal.GetValue(nameof(Seats)); var maxCollections = claimsPrincipal.GetValue(nameof(MaxCollections)); var useGroups = claimsPrincipal.GetValue(nameof(UseGroups)); @@ -425,12 +426,18 @@ public class OrganizationLicense : ILicense var useOrganizationDomains = claimsPrincipal.GetValue(nameof(UseOrganizationDomains)); var useAutomaticUserConfirmation = claimsPrincipal.GetValue(nameof(UseAutomaticUserConfirmation)); + var claimedPlanType = claimsPrincipal.GetValue(nameof(PlanType)); + + var planTypesMatch = claimedPlanType == PlanType.FamiliesAnnually + ? organization.PlanType is PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 + : organization.PlanType == claimedPlanType; + return issued <= DateTime.UtcNow && expires >= DateTime.UtcNow && installationId == globalSettings.Installation.Id && licenseKey == organization.LicenseKey && enabled == organization.Enabled && - planType == organization.PlanType && + planTypesMatch && seats == organization.Seats && maxCollections == organization.MaxCollections && useGroups == organization.UseGroups && diff --git a/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs b/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs index ee603c67e0..6c1362d1c5 100644 --- a/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs +++ b/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core.Billing.Models; using Bit.Core.Models.Business; using Stripe; @@ -17,7 +18,7 @@ public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate { _existingPlanStripeId = existingPlan.PasswordManager.StripePlanId; _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId - ?? Core.Utilities.StaticStore.SponsoredPlans.FirstOrDefault()?.StripePlanId; + ?? SponsoredPlans.All.FirstOrDefault()?.StripePlanId; _applySponsorship = applySponsorship; } diff --git a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs index f00bc00356..a8a236decc 100644 --- a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs @@ -22,14 +22,14 @@ public interface IGetCloudOrganizationLicenseQuery public class GetCloudOrganizationLicenseQuery : IGetCloudOrganizationLicenseQuery { private readonly IInstallationRepository _installationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ILicensingService _licensingService; private readonly IProviderRepository _providerRepository; private readonly IFeatureService _featureService; public GetCloudOrganizationLicenseQuery( IInstallationRepository installationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ILicensingService licensingService, IProviderRepository providerRepository, IFeatureService featureService) diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs index 01e520ea41..af8dfa7aec 100644 --- a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs @@ -9,7 +9,6 @@ using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -201,7 +200,7 @@ public class GetOrganizationWarningsQuery( // ReSharper disable once InvertIf if (subscription.Status == SubscriptionStatus.PastDue) { - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -257,8 +256,8 @@ public class GetOrganizationWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer diff --git a/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs index d34bd86e7b..39d2a789e6 100644 --- a/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs @@ -56,4 +56,11 @@ public interface IOrganizationBillingService /// Thrown when the is . /// Thrown when no payment method is found for the customer, no plan IDs are provided, or subscription update fails. Task UpdateSubscriptionPlanFrequency(Organization organization, PlanType newPlanType); + + /// + /// Updates the organization name and email on the Stripe customer entry. + /// This only updates Stripe, not the Bitwarden database. + /// + /// The organization to update in Stripe. + Task UpdateOrganizationNameAndEmail(Organization organization); } diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index b10f04d766..a1b57c2415 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -14,7 +14,6 @@ using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -161,7 +160,7 @@ public class OrganizationBillingService( try { // Update the subscription in Stripe - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, updateOptions); organization.PlanType = newPlan.Type; await organizationRepository.ReplaceAsync(organization); } @@ -176,6 +175,45 @@ public class OrganizationBillingService( } } + public async Task UpdateOrganizationNameAndEmail(Organization organization) + { + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + logger.LogWarning( + "Organization ({OrganizationId}) has no Stripe customer to update", + organization.Id); + return; + } + + var newDisplayName = organization.DisplayName(); + + // Organization.DisplayName() can return null - handle gracefully + if (string.IsNullOrWhiteSpace(newDisplayName)) + { + logger.LogWarning( + "Organization ({OrganizationId}) has no name to update in Stripe", + organization.Id); + return; + } + + await stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, + new CustomerUpdateOptions + { + Email = organization.BillingEmail, + Description = newDisplayName, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + // This overwrites the existing custom fields for this organization + CustomFields = [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = organization.SubscriberType(), + Value = newDisplayName + }] + }, + }); + } + #region Utilities private async Task CreateCustomerAsync( @@ -295,7 +333,7 @@ public class OrganizationBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -329,7 +367,7 @@ public class OrganizationBillingService( try { - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); organization.Gateway = GatewayType.Stripe; organization.GatewayCustomerId = customer.Id; @@ -480,7 +518,7 @@ public class OrganizationBillingService( subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; await organizationRepository.ReplaceAsync(organization); @@ -508,14 +546,14 @@ public class OrganizationBillingService( customer = customer switch { { Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, TaxExempt = StripeConstants.TaxExempt.Reverse }), { Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, @@ -574,7 +612,7 @@ public class OrganizationBillingService( } } }; - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, options); + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, options); } } diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index f4eca40cae..daf39fb981 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Microsoft.Extensions.Logging; using Stripe; @@ -46,7 +45,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -71,7 +70,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -92,7 +91,7 @@ public class UpdateBillingAddressCommand( await EnableAutomaticTaxAsync(subscriber, customer); var deleteExistingTaxIds = customer.TaxIds?.Any() ?? false - ? customer.TaxIds.Select(taxId => stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id)).ToList() + ? customer.TaxIds.Select(taxId => stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id)).ToList() : []; if (billingAddress.TaxId == null) @@ -101,12 +100,12 @@ public class UpdateBillingAddressCommand( return BillingAddress.From(customer.Address); } - var updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + var updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }); if (billingAddress.TaxId.Code == StripeConstants.TaxIdType.SpanishNIF) { - updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, @@ -130,7 +129,7 @@ public class UpdateBillingAddressCommand( if (subscription is { AutomaticTax.Enabled: false }) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs index 81206b8032..a5a9e3e9c9 100644 --- a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -56,7 +55,7 @@ public class UpdatePaymentMethodCommand( if (billingAddress != null && customer.Address is not { Country: not null, PostalCode: not null }) { - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions @@ -75,7 +74,7 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var setupIntents = await stripeAdapter.SetupIntentList(new SetupIntentListOptions + var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { Expand = ["data.payment_method"], PaymentMethod = token @@ -104,9 +103,9 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); + var paymentMethod = await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { DefaultPaymentMethod = token } @@ -139,7 +138,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomer.Id }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } var payPalAccount = braintreeCustomer.DefaultPaymentMethod as PayPalAccount; @@ -204,7 +203,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = string.Empty }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } } } diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs index 9f9618571e..e03a785278 100644 --- a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; using Stripe; @@ -53,7 +52,7 @@ public class GetPaymentMethodQuery( if (!string.IsNullOrEmpty(setupIntentId)) { - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); diff --git a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs index ec77ee0712..c972c3fe5f 100644 --- a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Stripe; namespace Bit.Core.Billing.Payment.Queries; @@ -48,7 +47,7 @@ public class HasPaymentMethodQuery( return false; } - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 1f752a007b..ed60e2f11c 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -80,6 +80,8 @@ public class CreatePremiumCloudHostedSubscriptionCommand( return new BadRequest("Additional storage must be greater than 0."); } + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + Customer? customer; /* @@ -107,7 +109,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( customer = await ReconcileBillingLocationAsync(customer, billingAddress); - var subscription = await CreateSubscriptionAsync(user.Id, customer, additionalStorageGb > 0 ? additionalStorageGb : null); + var subscription = await CreateSubscriptionAsync(user.Id, customer, premiumPlan, additionalStorageGb > 0 ? additionalStorageGb : null); paymentMethod.Switch( tokenized => @@ -140,7 +142,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; user.GatewaySubscriptionId = subscription.Id; - user.MaxStorageGb = (short)(1 + additionalStorageGb); + user.MaxStorageGb = (short)(premiumPlan.Storage.Provided + additionalStorageGb); user.LicenseKey = CoreHelpers.SecureRandomString(20); user.RevisionDate = DateTime.UtcNow; @@ -208,7 +210,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -241,7 +243,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch { @@ -298,15 +300,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( ValidateLocation = ValidateTaxLocationTiming.Immediately } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } private async Task CreateSubscriptionAsync( Guid userId, Customer customer, + Pricing.Premium.Plan premiumPlan, int? storage) { - var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); var subscriptionItemOptionsList = new List { @@ -347,11 +349,11 @@ public class CreatePremiumCloudHostedSubscriptionCommand( OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs index 5f09b8b77b..07247c83cb 100644 --- a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using Stripe; @@ -56,7 +56,7 @@ public class PreviewPremiumTaxCommand( }); } - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); diff --git a/src/Core/Billing/Premium/Models/UserPremiumAccess.cs b/src/Core/Billing/Premium/Models/UserPremiumAccess.cs new file mode 100644 index 0000000000..639d175d25 --- /dev/null +++ b/src/Core/Billing/Premium/Models/UserPremiumAccess.cs @@ -0,0 +1,29 @@ +namespace Bit.Core.Billing.Premium.Models; + +/// +/// Represents user premium access status from personal subscriptions and organization memberships. +/// +public class UserPremiumAccess +{ + /// + /// The unique identifier for the user. + /// + public Guid Id { get; set; } + + /// + /// Indicates whether the user has a personal premium subscription. + /// This does NOT include premium access from organizations. + /// + public bool PersonalPremium { get; set; } + + /// + /// Indicates whether the user has premium access through any organization membership. + /// This is true if the user is a member of at least one enabled organization that grants premium access to users. + /// + public bool OrganizationPremium { get; set; } + + /// + /// Indicates whether the user has premium access from any source (personal subscription or organization). + /// + public bool HasPremiumAccess => PersonalPremium || OrganizationPremium; +} diff --git a/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs b/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs new file mode 100644 index 0000000000..e90710a9b3 --- /dev/null +++ b/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs @@ -0,0 +1,49 @@ +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Core.Billing.Premium.Queries; + +public class HasPremiumAccessQuery : IHasPremiumAccessQuery +{ + private readonly IUserRepository _userRepository; + + public HasPremiumAccessQuery(IUserRepository userRepository) + { + _userRepository = userRepository; + } + + public async Task HasPremiumAccessAsync(Guid userId) + { + var user = await _userRepository.GetPremiumAccessAsync(userId); + if (user == null) + { + throw new NotFoundException(); + } + + return user.HasPremiumAccess; + } + + public async Task> HasPremiumAccessAsync(IEnumerable userIds) + { + var distinctUserIds = userIds.Distinct().ToList(); + var usersWithPremium = await _userRepository.GetPremiumAccessByIdsAsync(distinctUserIds); + + if (usersWithPremium.Count() != distinctUserIds.Count) + { + throw new NotFoundException(); + } + + return usersWithPremium.ToDictionary(u => u.Id, u => u.HasPremiumAccess); + } + + public async Task HasPremiumFromOrganizationAsync(Guid userId) + { + var user = await _userRepository.GetPremiumAccessAsync(userId); + if (user == null) + { + throw new NotFoundException(); + } + + return user.OrganizationPremium; + } +} diff --git a/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs b/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs new file mode 100644 index 0000000000..e5545b1ade --- /dev/null +++ b/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs @@ -0,0 +1,30 @@ +namespace Bit.Core.Billing.Premium.Queries; + +/// +/// Centralized query for checking if users have premium access through personal subscriptions or organizations. +/// Note: Different from User.Premium which only checks personal subscriptions. +/// +public interface IHasPremiumAccessQuery +{ + /// + /// Checks if a user has premium access (personal or organization). + /// + /// The user ID to check + /// True if user can access premium features + Task HasPremiumAccessAsync(Guid userId); + + /// + /// Checks premium access for multiple users. + /// + /// The user IDs to check + /// Dictionary mapping user IDs to their premium access status + Task> HasPremiumAccessAsync(IEnumerable userIds); + + /// + /// Checks if a user belongs to any organization that grants premium (enabled org with UsersGetPremium). + /// Returns true regardless of personal subscription. Useful for UI decisions like showing subscription options. + /// + /// The user ID to check + /// True if user is in any organization that grants premium + Task HasPremiumFromOrganizationAsync(Guid userId); +} diff --git a/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs index 37dc63cb47..42090a56ca 100644 --- a/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs +++ b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs @@ -99,7 +99,7 @@ public record PlanAdapter : Core.Models.StaticStore.Plan _ => true); var baseSeats = GetBaseSeats(plan.Seats); var maxSeats = GetMaxSeats(plan.Seats); - var baseStorageGb = (short?)plan.Storage?.Provided; + var baseStorageGb = (short)(plan.Storage?.Provided ?? 0); var hasAdditionalStorageOption = plan.Storage != null; var additionalStoragePricePerGb = plan.Storage?.Price ?? 0; var stripeStoragePlanId = plan.Storage?.StripePriceId; diff --git a/src/Core/Billing/Pricing/Premium/Purchasable.cs b/src/Core/Billing/Pricing/Premium/Purchasable.cs index 633eb2e8aa..6bf69d9593 100644 --- a/src/Core/Billing/Pricing/Premium/Purchasable.cs +++ b/src/Core/Billing/Pricing/Premium/Purchasable.cs @@ -4,4 +4,5 @@ public class Purchasable { public string StripePriceId { get; init; } = null!; public decimal Price { get; init; } + public int Provided { get; init; } } diff --git a/src/Core/Billing/Pricing/PricingClient.cs b/src/Core/Billing/Pricing/PricingClient.cs index 1ec44c6496..ecb85ed7e8 100644 --- a/src/Core/Billing/Pricing/PricingClient.cs +++ b/src/Core/Billing/Pricing/PricingClient.cs @@ -6,7 +6,6 @@ using Bit.Core.Billing.Pricing.Organizations; using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; using Microsoft.Extensions.Logging; namespace Bit.Core.Billing.Pricing; @@ -28,13 +27,6 @@ public class PricingClient( return null; } - var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); - - if (!usePricingService) - { - return StaticStore.GetPlan(planType); - } - var lookupKey = GetLookupKey(planType); if (lookupKey == null) @@ -77,13 +69,6 @@ public class PricingClient( return []; } - var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); - - if (!usePricingService) - { - return StaticStore.Plans.ToList(); - } - var response = await httpClient.GetAsync("plans/organization"); if (response.IsSuccessStatusCode) @@ -114,11 +99,10 @@ public class PricingClient( return []; } - var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); var fetchPremiumPriceFromPricingService = featureService.IsEnabled(FeatureFlagKeys.PM26793_FetchPremiumPriceFromPricingService); - if (!usePricingService || !fetchPremiumPriceFromPricingService) + if (!fetchPremiumPriceFromPricingService) { return [CurrentPremiumPlan]; } @@ -186,6 +170,6 @@ public class PricingClient( Available = true, LegacyYear = null, Seat = new Purchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, - Storage = new Purchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal } + Storage = new Purchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } }; } diff --git a/src/Core/Billing/Providers/Services/IProviderBillingService.cs b/src/Core/Billing/Providers/Services/IProviderBillingService.cs index 57d68db038..3f5a48e817 100644 --- a/src/Core/Billing/Providers/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Providers/Services/IProviderBillingService.cs @@ -113,4 +113,11 @@ public interface IProviderBillingService TaxInformation taxInformation); Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command); + + /// + /// Updates the provider name and email on the Stripe customer entry. + /// This only updates Stripe, not the Bitwarden database. + /// + /// The provider to update in Stripe. + Task UpdateProviderNameAndEmail(Provider provider); } diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs new file mode 100644 index 0000000000..5ec732920e --- /dev/null +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -0,0 +1,50 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; + +namespace Bit.Core.Billing.Services; + +public interface IStripeAdapter +{ + Task CreateCustomerAsync(CustomerCreateOptions customerCreateOptions); + Task GetCustomerAsync(string id, CustomerGetOptions options = null); + Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null); + Task DeleteCustomerAsync(string id); + Task> ListCustomerPaymentMethodsAsync(string id, CustomerPaymentMethodListOptions options = null); + Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options); + Task CreateSubscriptionAsync(SubscriptionCreateOptions subscriptionCreateOptions); + Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null); + Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null); + Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); + Task UpdateSubscriptionAsync(string id, SubscriptionUpdateOptions options = null); + Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); + Task GetInvoiceAsync(string id, InvoiceGetOptions options); + Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); + Task> SearchInvoiceAsync(InvoiceSearchOptions options); + Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); + Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); + Task SendInvoiceAsync(string id, InvoiceSendOptions options); + Task PayInvoiceAsync(string id, InvoicePayOptions options = null); + Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null); + Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null); + IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options); + IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options); + Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null); + Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null); + Task CreateTaxIdAsync(string id, TaxIdCreateOptions options); + Task DeleteTaxIdAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); + Task> ListChargesAsync(ChargeListOptions options); + Task CreateRefundAsync(RefundCreateOptions options); + Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null); + Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); + Task CreateSetupIntentAsync(SetupIntentCreateOptions options); + Task> ListSetupIntentsAsync(SetupIntentListOptions options); + Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null); + Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null); + Task GetPriceAsync(string id, PriceGetOptions options = null); +} diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Billing/Services/IStripePaymentService.cs similarity index 85% rename from src/Core/Services/IPaymentService.cs rename to src/Core/Billing/Services/IStripePaymentService.cs index e7e848bcba..b948cf6921 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Billing/Services/IStripePaymentService.cs @@ -4,15 +4,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Business; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Billing.Tax.Responses; using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services; -public interface IPaymentService +public interface IStripePaymentService { Task CancelAndRecoverChargesAsync(ISubscriber subscriber); Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); @@ -44,8 +42,6 @@ public interface IPaymentService Task GetBillingAsync(ISubscriber subscriber); Task GetBillingHistoryAsync(ISubscriber subscriber); Task GetSubscriptionAsync(ISubscriber subscriber); - Task GetTaxInfoAsync(ISubscriber subscriber); - Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); Task AddSecretsManagerToSubscription(Organization org, Plan plan, int additionalSmSeats, int additionalServiceAccount); /// /// Secrets Manager Standalone is a discount in Stripe that is used to give an organization access to Secrets Manager. @@ -68,7 +64,4 @@ public interface IPaymentService /// Organization Representation used for Inviting Organization Users /// If the organization has Secrets Manager and has the Standalone Stripe Discount Task HasSecretsManagerStandalone(InviteOrganization organization); - Task PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - Task PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - } diff --git a/src/Core/Billing/Services/IStripeSyncService.cs b/src/Core/Billing/Services/IStripeSyncService.cs new file mode 100644 index 0000000000..b56204cd47 --- /dev/null +++ b/src/Core/Billing/Services/IStripeSyncService.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Billing.Services; + +public interface IStripeSyncService +{ + Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress); +} diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index f88727f37b..343a0e4f38 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -6,7 +6,6 @@ using Bit.Core.Billing.Tax.Models; using Bit.Core.Entities; using Bit.Core.Enums; using Stripe; -using PaymentMethod = Bit.Core.Billing.Models.PaymentMethod; namespace Bit.Core.Billing.Services; @@ -64,16 +63,6 @@ public interface ISubscriberService ISubscriber subscriber, CustomerGetOptions customerGetOptions = null); - /// - /// Retrieves the account credit, a masked representation of the default payment source and the tax information for the - /// provided . This is essentially a consolidated invocation of the - /// and methods with a response that includes the customer's as account credit in order to cut down on Stripe API calls. - /// - /// The subscriber to retrieve payment method for. - /// A containing the subscriber's account credit, payment source and tax information. - Task GetPaymentMethod( - ISubscriber subscriber); - /// /// Retrieves a masked representation of the subscriber's payment source for presentation to a client. /// @@ -107,16 +96,6 @@ public interface ISubscriberService ISubscriber subscriber, SubscriptionGetOptions subscriptionGetOptions = null); - /// - /// Retrieves the 's tax information using their Stripe 's . - /// - /// The subscriber to retrieve the tax information for. - /// A representing the 's tax information. - /// Thrown when the is . - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetTaxInformation( - ISubscriber subscriber); - /// /// Attempts to remove a subscriber's saved payment source. If the Stripe representing the /// contains a valid "btCustomerId" key in its property, @@ -147,17 +126,6 @@ public interface ISubscriberService ISubscriber subscriber, TaxInformation taxInformation); - /// - /// Verifies the subscriber's pending bank account using the provided . - /// - /// The subscriber to verify the bank account for. - /// The code attached to a deposit made to the subscriber's bank account in order to ensure they have access to it. - /// Learn more. - /// - Task VerifyBankAccount( - ISubscriber subscriber, - string descriptorCode); - /// /// Validates whether the 's exists in the gateway. /// If the 's is or empty, returns . diff --git a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs index 5a8cf16f5a..16b3f7e0c3 100644 --- a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs +++ b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.Billing.Services.Implementations; @@ -23,7 +22,7 @@ public class PaymentHistoryService( return Array.Empty(); } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = subscriber.GatewayCustomerId, Limit = pageSize, diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 3170060de4..9c85971dff 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -12,7 +12,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -68,7 +67,7 @@ public class PremiumUserBillingService( } }; - customer = await stripeAdapter.CustomerCreateAsync(options); + customer = await stripeAdapter.CreateCustomerAsync(options); user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; @@ -81,7 +80,7 @@ public class PremiumUserBillingService( Balance = customer.Balance + credit }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } @@ -101,7 +100,9 @@ public class PremiumUserBillingService( */ customer = await ReconcileBillingLocationAsync(customer, customerSetup.TaxInformation); - var subscription = await CreateSubscriptionAsync(user.Id, customer, storage); + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + + var subscription = await CreateSubscriptionAsync(user.Id, customer, premiumPlan, storage); switch (customerSetup.TokenizedPaymentSource) { @@ -119,6 +120,7 @@ public class PremiumUserBillingService( user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; user.GatewaySubscriptionId = subscription.Id; + user.MaxStorageGb = (short)(premiumPlan.Storage.Provided + (storage ?? 0)); await userRepository.ReplaceAsync(user); } @@ -224,7 +226,7 @@ public class PremiumUserBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -257,7 +259,7 @@ public class PremiumUserBillingService( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) @@ -301,9 +303,9 @@ public class PremiumUserBillingService( private async Task CreateSubscriptionAsync( Guid userId, Customer customer, + Pricing.Premium.Plan premiumPlan, int? storage) { - var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); var subscriptionItemOptionsList = new List { @@ -344,11 +346,11 @@ public class PremiumUserBillingService( OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); @@ -384,6 +386,6 @@ public class PremiumUserBillingService( } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs new file mode 100644 index 0000000000..cdc7645042 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -0,0 +1,209 @@ +// FIXME: Update this file to be null safe and then delete the line below + +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; +using Stripe.TestHelpers; +using CustomerService = Stripe.CustomerService; +using RefundService = Stripe.RefundService; + +namespace Bit.Core.Billing.Services.Implementations; + +public class StripeAdapter : IStripeAdapter +{ + private readonly CustomerService _customerService; + private readonly SubscriptionService _subscriptionService; + private readonly InvoiceService _invoiceService; + private readonly PaymentMethodService _paymentMethodService; + private readonly TaxIdService _taxIdService; + private readonly ChargeService _chargeService; + private readonly RefundService _refundService; + private readonly CardService _cardService; + private readonly BankAccountService _bankAccountService; + private readonly PriceService _priceService; + private readonly SetupIntentService _setupIntentService; + private readonly TestClockService _testClockService; + private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; + private readonly RegistrationService _taxRegistrationService; + + public StripeAdapter() + { + _customerService = new CustomerService(); + _subscriptionService = new SubscriptionService(); + _invoiceService = new InvoiceService(); + _paymentMethodService = new PaymentMethodService(); + _taxIdService = new TaxIdService(); + _chargeService = new ChargeService(); + _refundService = new RefundService(); + _cardService = new CardService(); + _bankAccountService = new BankAccountService(); + _priceService = new PriceService(); + _setupIntentService = new SetupIntentService(); + _testClockService = new TestClockService(); + _customerBalanceTransactionService = new CustomerBalanceTransactionService(); + _taxRegistrationService = new RegistrationService(); + } + + /************** + ** CUSTOMER ** + **************/ + public Task CreateCustomerAsync(CustomerCreateOptions options) => + _customerService.CreateAsync(options); + + public Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => + _customerService.DeleteDiscountAsync(customerId, options); + + public Task GetCustomerAsync(string id, CustomerGetOptions options = null) => + _customerService.GetAsync(id, options); + + public Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null) => + _customerService.UpdateAsync(id, options); + + public Task DeleteCustomerAsync(string id) => + _customerService.DeleteAsync(id); + + public async Task> ListCustomerPaymentMethodsAsync(string id, + CustomerPaymentMethodListOptions options = null) + { + var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); + return paymentMethods.Data; + } + + public Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options) => + _customerBalanceTransactionService.CreateAsync(customerId, options); + + /****************** + ** SUBSCRIPTION ** + ******************/ + public Task CreateSubscriptionAsync(SubscriptionCreateOptions options) => + _subscriptionService.CreateAsync(options); + + public Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null) => + _subscriptionService.GetAsync(id, options); + + public Task UpdateSubscriptionAsync(string id, + SubscriptionUpdateOptions options = null) => + _subscriptionService.UpdateAsync(id, options); + + public Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null) => + _subscriptionService.CancelAsync(id, options); + + /************* + ** INVOICE ** + *************/ + public Task GetInvoiceAsync(string id, InvoiceGetOptions options) => + _invoiceService.GetAsync(id, options); + + public async Task> ListInvoicesAsync(StripeInvoiceListOptions options) + { + if (!options.SelectAll) + { + return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; + } + + options.Limit = 100; + + var invoices = new List(); + + await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) + { + invoices.Add(invoice); + } + + return invoices; + } + + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => + _invoiceService.CreatePreviewAsync(options); + + public async Task> SearchInvoiceAsync(InvoiceSearchOptions options) => + (await _invoiceService.SearchAsync(options)).Data; + + public Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options) => + _invoiceService.UpdateAsync(id, options); + + public Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) => + _invoiceService.FinalizeInvoiceAsync(id, options); + + public Task SendInvoiceAsync(string id, InvoiceSendOptions options) => + _invoiceService.SendInvoiceAsync(id, options); + + public Task PayInvoiceAsync(string id, InvoicePayOptions options = null) => + _invoiceService.PayAsync(id, options); + + public Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null) => + _invoiceService.DeleteAsync(id, options); + + public Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null) => + _invoiceService.VoidInvoiceAsync(id, options); + + /******************** + ** PAYMENT METHOD ** + ********************/ + public IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options) => + _paymentMethodService.ListAutoPaging(options); + + public IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options) + => _paymentMethodService.ListAutoPagingAsync(options); + + public Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null) => + _paymentMethodService.AttachAsync(id, options); + + public Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null) => + _paymentMethodService.DetachAsync(id, options); + + /************ + ** TAX ID ** + ************/ + public Task CreateTaxIdAsync(string id, TaxIdCreateOptions options) => + _taxIdService.CreateAsync(id, options); + + public Task DeleteTaxIdAsync(string customerId, string taxIdId, + TaxIdDeleteOptions options = null) => + _taxIdService.DeleteAsync(customerId, taxIdId, options); + + /****************** + ** BANK ACCOUNT ** + ******************/ + public Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) => + _bankAccountService.DeleteAsync(customerId, bankAccount, options); + + /*********** + ** PRICE ** + ***********/ + public Task GetPriceAsync(string id, PriceGetOptions options = null) => + _priceService.GetAsync(id, options); + + /****************** + ** SETUP INTENT ** + ******************/ + public Task CreateSetupIntentAsync(SetupIntentCreateOptions options) => + _setupIntentService.CreateAsync(options); + + public async Task> ListSetupIntentsAsync(SetupIntentListOptions options) => + (await _setupIntentService.ListAsync(options)).Data; + + public Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null) => + _setupIntentService.CancelAsync(id, options); + + public Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null) => + _setupIntentService.GetAsync(id, options); + + /******************* + ** MISCELLANEOUS ** + *******************/ + public Task> ListChargesAsync(ChargeListOptions options) => + _chargeService.ListAsync(options); + + public Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null) => + _taxRegistrationService.ListAsync(options); + + public Task CreateRefundAsync(RefundCreateOptions options) => + _refundService.CreateAsync(options); + + public Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null) => + _cardService.DeleteAsync(customerId, cardId, options); +} diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Billing/Services/Implementations/StripePaymentService.cs similarity index 61% rename from src/Core/Services/Implementations/StripePaymentService.cs rename to src/Core/Billing/Services/Implementations/StripePaymentService.cs index 5dd1ff50e7..ffc18aa748 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Billing/Services/Implementations/StripePaymentService.cs @@ -8,11 +8,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; -using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Billing.Tax.Responses; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -25,9 +21,9 @@ using Stripe; using PaymentMethod = Stripe.PaymentMethod; using StaticStore = Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; -public class StripePaymentService : IPaymentService +public class StripePaymentService : IStripePaymentService { private const string SecretsManagerStandaloneDiscountId = "sm-standalone"; @@ -36,8 +32,6 @@ public class StripePaymentService : IPaymentService private readonly Braintree.IBraintreeGateway _btGateway; private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; - private readonly IFeatureService _featureService; - private readonly ITaxService _taxService; private readonly IPricingClient _pricingClient; public StripePaymentService( @@ -46,8 +40,6 @@ public class StripePaymentService : IPaymentService IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, - IFeatureService featureService, - ITaxService taxService, IPricingClient pricingClient) { _transactionRepository = transactionRepository; @@ -55,8 +47,6 @@ public class StripePaymentService : IPaymentService _stripeAdapter = stripeAdapter; _btGateway = braintreeGateway; _globalSettings = globalSettings; - _featureService = featureService; - _taxService = taxService; _pricingClient = pricingClient; } @@ -67,14 +57,14 @@ public class StripePaymentService : IPaymentService { var existingPlan = await _pricingClient.GetPlanOrThrow(org.PlanType); var sponsoredPlan = sponsorship?.PlanSponsorshipType != null - ? Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) + ? SponsoredPlans.Get(sponsorship.PlanSponsorshipType.Value) : null; var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true); - var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(org.GatewaySubscriptionId); org.ExpirationDate = sub.GetCurrentPeriodEnd(); if (sponsorship is not null) @@ -94,7 +84,7 @@ public class StripePaymentService : IPaymentService { // remember, when in doubt, throw var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] }; - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { throw new GatewayException("Subscription not found."); @@ -117,7 +107,7 @@ public class StripePaymentService : IPaymentService var subUpdateOptions = new SubscriptionUpdateOptions { Items = updatedItemOptions, - ProrationBehavior = invoiceNow ? Constants.AlwaysInvoice : Constants.CreateProrations, + ProrationBehavior = invoiceNow ? Core.Constants.AlwaysInvoice : Core.Constants.CreateProrations, DaysUntilDue = daysUntilDue ?? 1, CollectionMethod = "send_invoice" }; @@ -131,11 +121,11 @@ public class StripePaymentService : IPaymentService { if (sub.Customer is { - Address.Country: not Constants.CountryAbbreviations.UnitedStates, + Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse }) { - await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, + await _stripeAdapter.UpdateCustomerAsync(sub.CustomerId, new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); } @@ -151,9 +141,9 @@ public class StripePaymentService : IPaymentService string paymentIntentClientSecret = null; try { - var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); + var subResponse = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, subUpdateOptions); - var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); + var invoice = await _stripeAdapter.GetInvoiceAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); if (invoice == null) { throw new BadRequestException("Unable to locate draft invoice for subscription update."); @@ -172,9 +162,9 @@ public class StripePaymentService : IPaymentService } else { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, + invoice = await _stripeAdapter.FinalizeInvoiceAsync(subResponse.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = false, }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); + await _stripeAdapter.SendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); paymentIntentClientSecret = null; } } @@ -182,7 +172,7 @@ public class StripePaymentService : IPaymentService catch { // Need to revert the subscription - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { Items = subscriptionUpdate.RevertItemsOptions(sub), // This proration behavior prevents a false "credit" from @@ -197,7 +187,7 @@ public class StripePaymentService : IPaymentService else if (invoice.Status != StripeConstants.InvoiceStatus.Paid) { // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h - invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); + invoice = await _stripeAdapter.PayInvoiceAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } } @@ -206,7 +196,7 @@ public class StripePaymentService : IPaymentService // Change back the subscription collection method and/or days until due if (collectionMethod != "send_invoice" || daysUntilDue == null) { - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CollectionMethod = collectionMethod, @@ -214,14 +204,14 @@ public class StripePaymentService : IPaymentService }); } - var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(sub.CustomerId); var newCoupon = customer.Discount?.Coupon?.Id; if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon)) { // Re-add the lost coupon due to the update. - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { Discounts = [ @@ -294,7 +284,7 @@ public class StripePaymentService : IPaymentService { if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, + await _stripeAdapter.CancelSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionCancelOptions()); } @@ -303,7 +293,7 @@ public class StripePaymentService : IPaymentService return; } - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); if (customer == null) { return; @@ -328,7 +318,7 @@ public class StripePaymentService : IPaymentService } else { - var charges = await _stripeAdapter.ChargeListAsync(new ChargeListOptions + var charges = await _stripeAdapter.ListChargesAsync(new ChargeListOptions { Customer = subscriber.GatewayCustomerId }); @@ -337,12 +327,12 @@ public class StripePaymentService : IPaymentService { foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) { - await _stripeAdapter.RefundCreateAsync(new RefundCreateOptions { Charge = charge.Id }); + await _stripeAdapter.CreateRefundAsync(new RefundCreateOptions { Charge = charge.Id }); } } } - await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); + await _stripeAdapter.DeleteCustomerAsync(subscriber.GatewayCustomerId); } public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Invoice invoice) @@ -350,7 +340,7 @@ public class StripePaymentService : IPaymentService var customerOptions = new CustomerGetOptions(); customerOptions.AddExpand("default_source"); customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); string paymentIntentClientSecret = null; @@ -370,13 +360,13 @@ public class StripePaymentService : IPaymentService // We're going to delete this draft invoice, it can't be paid try { - await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + await _stripeAdapter.DeleteInvoiceAsync(invoice.Id); } catch { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions { AutoAdvance = false }); - await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + await _stripeAdapter.VoidInvoiceAsync(invoice.Id); } throw new BadRequestException("No payment method is available."); @@ -389,7 +379,7 @@ public class StripePaymentService : IPaymentService { // Finalize the invoice (from Draft) w/o auto-advance so we // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + invoice = await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions { AutoAdvance = false, }); var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, }; if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) @@ -424,7 +414,7 @@ public class StripePaymentService : IPaymentService } braintreeTransaction = transactionResult.Target; - invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new InvoiceUpdateOptions + invoice = await _stripeAdapter.UpdateInvoiceAsync(invoice.Id, new InvoiceUpdateOptions { Metadata = new Dictionary { @@ -438,7 +428,7 @@ public class StripePaymentService : IPaymentService try { - invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); + invoice = await _stripeAdapter.PayInvoiceAsync(invoice.Id, invoicePayOptions); } catch (StripeException e) { @@ -448,7 +438,7 @@ public class StripePaymentService : IPaymentService // SCA required, get intent client secret var invoiceGetOptions = new InvoiceGetOptions(); invoiceGetOptions.AddExpand("confirmation_secret"); - invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); + invoice = await _stripeAdapter.GetInvoiceAsync(invoice.Id, invoiceGetOptions); paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret; } else @@ -472,7 +462,7 @@ public class StripePaymentService : IPaymentService return paymentIntentClientSecret; } - invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); + invoice = await _stripeAdapter.VoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); // HACK: Workaround for customer balance credit if (invoice.StartingBalance < 0) @@ -480,12 +470,12 @@ public class StripePaymentService : IPaymentService // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to // credit it back to the customer (even though their docs claim they will), we need to // check that balance against the current customer balance and determine if it needs to be re-applied - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); // Assumption: Customer balance should now be $0, otherwise payment would not have failed. if (customer.Balance == 0) { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Balance = invoice.StartingBalance }); } } @@ -516,7 +506,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -532,9 +522,9 @@ public class StripePaymentService : IPaymentService try { var canceledSub = endOfPeriod - ? await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + ? await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) - : await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions()); + : await _stripeAdapter.CancelSubscriptionAsync(sub.Id, new SubscriptionCancelOptions()); if (!canceledSub.CanceledAt.HasValue) { throw new GatewayException("Unable to cancel subscription."); @@ -561,7 +551,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -573,7 +563,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("Subscription is not marked for cancellation."); } - var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + var updatedSub = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); if (updatedSub.CanceledAt.HasValue) { @@ -588,11 +578,11 @@ public class StripePaymentService : IPaymentService !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); if (customerExists) { - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); } else { - customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions + customer = await _stripeAdapter.CreateCustomerAsync(new CustomerCreateOptions { Email = subscriber.BillingEmailAddress(), Description = subscriber.BillingName(), @@ -601,9 +591,8 @@ public class StripePaymentService : IPaymentService subscriber.GatewayCustomerId = customer.Id; } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) }); - return !customerExists; } @@ -640,7 +629,7 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, + var subscription = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.discount.coupon.applies_to", "discounts.coupon.applies_to", "test_clock"] }); if (subscription == null) @@ -685,7 +674,7 @@ public class StripePaymentService : IPaymentService Subscription = subscriber.GatewaySubscriptionId }; - var upcomingInvoice = await _stripeAdapter.InvoiceCreatePreviewAsync(invoiceCreatePreviewOptions); + var upcomingInvoice = await _stripeAdapter.CreateInvoicePreviewAsync(invoiceCreatePreviewOptions); if (upcomingInvoice != null) { @@ -705,133 +694,6 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - public async Task GetTaxInfoAsync(ISubscriber subscriber) - { - if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, - new CustomerGetOptions { Expand = ["tax_ids"] }); - - if (customer == null) - { - return null; - } - - var address = customer.Address; - var taxId = customer.TaxIds?.FirstOrDefault(); - - // Line1 is required, so if missing we're using the subscriber name, - // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 - if (address != null && string.IsNullOrWhiteSpace(address.Line1)) - { - address.Line1 = null; - } - - return new TaxInfo - { - TaxIdNumber = taxId?.Value, - TaxIdType = taxId?.Type, - BillingAddressLine1 = address?.Line1, - BillingAddressLine2 = address?.Line2, - BillingAddressCity = address?.City, - BillingAddressState = address?.State, - BillingAddressPostalCode = address?.PostalCode, - BillingAddressCountry = address?.Country, - }; - } - - public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) - { - if (string.IsNullOrWhiteSpace(subscriber?.GatewayCustomerId) || subscriber.IsUser()) - { - return; - } - - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions - { - Address = new AddressOptions - { - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - PostalCode = taxInfo.BillingAddressPostalCode, - Country = taxInfo.BillingAddressCountry, - }, - Expand = ["tax_ids"] - }); - - if (customer == null) - { - return; - } - - var taxId = customer.TaxIds?.FirstOrDefault(); - - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - - if (string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) - { - return; - } - - var taxIdType = taxInfo.TaxIdType; - - if (string.IsNullOrWhiteSpace(taxIdType)) - { - taxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); - - if (taxIdType == null) - { - _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - } - - try - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }); - - if (taxInfo.TaxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{taxInfo.TaxIdNumber}" - }); - } - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry); - throw new BadRequestException("billingInvalidTaxIdError"); - default: - _logger.LogError(e, - "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry, - customer.Id); - throw new BadRequestException("billingTaxIdCreationError"); - } - } - } - public async Task AddSecretsManagerToSubscription( Organization org, StaticStore.Plan plan, @@ -863,7 +725,7 @@ public class StripePaymentService : IPaymentService return false; } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); return customer?.Discount?.Coupon?.Id == SecretsManagerStandaloneDiscountId; } @@ -875,7 +737,7 @@ public class StripePaymentService : IPaymentService return (null, null); } - var openInvoices = await _stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await _stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -909,312 +771,9 @@ public class StripePaymentService : IPaymentService } } - [Obsolete($"Use {nameof(PreviewPremiumTaxCommand)} instead.")] - public async Task PreviewInvoiceAsync( - PreviewIndividualInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); - - var options = new InvoiceCreatePreviewOptions - { - AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true, }, - Currency = "usd", - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new InvoiceSubscriptionDetailsItemOptions - { - Quantity = 1, - Plan = premiumPlan.Seat.StripePriceId - }, - - new InvoiceSubscriptionDetailsItemOptions - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = premiumPlan.Storage.StripePriceId - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - } - - options.CustomerDetails.TaxIds = - [ - new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Discounts = [new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }]; - } - } - - if (!string.IsNullOrWhiteSpace(gatewaySubscriptionId)) - { - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discounts is { Count: > 0 }) - { - options.Discounts = gatewaySubscription.Discounts.Select(x => new InvoiceDiscountOptions { Coupon = x.Coupon.Id }).ToList(); - } - } - - if (options.Discounts is { Count: > 0 }) - { - options.Discounts = options.Discounts.DistinctBy(invoiceDiscountOptions => invoiceDiscountOptions.Coupon).ToList(); - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount); - - var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - tax.ToMajor(), - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, - "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - - public async Task PreviewInvoiceAsync( - PreviewOrganizationInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var plan = await _pricingClient.GetPlanOrThrow(parameters.PasswordManager.Plan); - var isSponsored = parameters.PasswordManager.SponsoredPlan.HasValue; - - var options = new InvoiceCreatePreviewOptions - { - Currency = "usd", - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new InvoiceSubscriptionDetailsItemOptions - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = plan.PasswordManager.StripeStoragePlanId - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (isSponsored) - { - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(parameters.PasswordManager.SponsoredPlan.Value); - options.SubscriptionDetails.Items.Add( - new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = sponsoredPlan.StripePlanId } - ); - } - else - { - if (plan.PasswordManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add( - new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.Seats, Plan = plan.PasswordManager.StripeSeatPlanId } - ); - } - else - { - options.SubscriptionDetails.Items.Add( - new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = plan.PasswordManager.StripePlanId } - ); - } - - if (plan.SupportsSecretsManager) - { - if (plan.SecretsManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions - { - Quantity = parameters.SecretsManager?.Seats ?? 0, - Plan = plan.SecretsManager.StripeSeatPlanId - }); - } - - if (plan.SecretsManager.HasAdditionalServiceAccountOption) - { - options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions - { - Quantity = parameters.SecretsManager?.AdditionalMachineAccounts ?? 0, - Plan = plan.SecretsManager.StripeServiceAccountPlanId - }); - } - } - } - - if (!string.IsNullOrWhiteSpace(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - - options.CustomerDetails.TaxIds = - [ - new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - Customer gatewayCustomer = null; - - if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Discounts = - [ - new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id } - ]; - } - } - - if (!string.IsNullOrWhiteSpace(gatewaySubscriptionId)) - { - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discounts != null) - { - options.Discounts = gatewaySubscription.Discounts - .Select(discount => new InvoiceDiscountOptions { Coupon = discount.Coupon.Id }).ToList(); - } - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - if (parameters.PasswordManager.Plan.IsBusinessProductTierType() && - parameters.TaxInformation.Country != Constants.CountryAbbreviations.UnitedStates) - { - options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount); - - var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - tax.ToMajor(), - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, - "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - private PaymentMethod GetLatestCardPaymentMethod(string customerId) { - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( + var cardPaymentMethods = _stripeAdapter.ListPaymentMethodsAutoPaging( new PaymentMethodListOptions { Customer = customerId, Type = "card" }); return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); } @@ -1277,7 +836,7 @@ public class StripePaymentService : IPaymentService Customer customer = null; try { - customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); + customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId, options); } catch (StripeException) { @@ -1310,21 +869,21 @@ public class StripePaymentService : IPaymentService try { - var paidInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var paidInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "paid" }); - var openInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var openInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "open" }); - var uncollectibleInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var uncollectibleInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, diff --git a/src/Core/Services/Implementations/StripeSyncService.cs b/src/Core/Billing/Services/Implementations/StripeSyncService.cs similarity index 68% rename from src/Core/Services/Implementations/StripeSyncService.cs rename to src/Core/Billing/Services/Implementations/StripeSyncService.cs index b2700e65d1..31dd89d72d 100644 --- a/src/Core/Services/Implementations/StripeSyncService.cs +++ b/src/Core/Billing/Services/Implementations/StripeSyncService.cs @@ -1,6 +1,6 @@ using Bit.Core.Exceptions; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; public class StripeSyncService : IStripeSyncService { @@ -11,7 +11,7 @@ public class StripeSyncService : IStripeSyncService _stripeAdapter = stripeAdapter; } - public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) + public async Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress) { if (string.IsNullOrWhiteSpace(gatewayCustomerId)) { @@ -23,9 +23,9 @@ public class StripeSyncService : IStripeSyncService throw new InvalidEmailException(); } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new Stripe.CustomerUpdateOptions { Email = emailAddress }); } } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 8e75bf3dca..7acbe20014 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -15,7 +15,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -24,7 +23,6 @@ using Stripe; using static Bit.Core.Billing.Utilities; using Customer = Stripe.Customer; -using PaymentMethod = Bit.Core.Billing.Models.PaymentMethod; using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Services.Implementations; @@ -79,7 +77,7 @@ public class SubscriberService( { if (subscription.Metadata != null && subscription.Metadata.ContainsKey("organizationId")) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Metadata = metadata }); @@ -98,7 +96,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionCancelAsync(subscription.Id, options); + await stripeAdapter.CancelSubscriptionAsync(subscription.Id, options); } else { @@ -117,7 +115,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, options); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); } } @@ -228,7 +226,7 @@ public class SubscriberService( _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) }; - var customer = await stripeAdapter.CustomerCreateAsync(options); + var customer = await stripeAdapter.CreateCustomerAsync(options); switch (subscriber) { @@ -271,7 +269,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -307,7 +305,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -330,38 +328,6 @@ public class SubscriberService( } } - public async Task GetPaymentMethod( - ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - var customer = await GetCustomer(subscriber, new CustomerGetOptions - { - Expand = ["default_source", "invoice_settings.default_payment_method", "subscriptions", "tax_ids"] - }); - - if (customer == null) - { - return PaymentMethod.Empty; - } - - var accountCredit = customer.Balance * -1 / 100M; - - var paymentMethod = await GetPaymentSourceAsync(subscriber.Id, customer); - - var subscriptionStatus = customer.Subscriptions - .FirstOrDefault(subscription => subscription.Id == subscriber.GatewaySubscriptionId)? - .Status; - - var taxInformation = GetTaxInformation(customer); - - return new PaymentMethod( - accountCredit, - paymentMethod, - subscriptionStatus, - taxInformation); - } - public async Task GetPaymentSource( ISubscriber subscriber) { @@ -390,7 +356,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -426,7 +392,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -449,16 +415,6 @@ public class SubscriberService( } } - public async Task GetTaxInformation( - ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions { Expand = ["tax_ids"] }); - - return GetTaxInformation(customer); - } - public async Task RemovePaymentSource( ISubscriber subscriber) { @@ -530,23 +486,23 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(stripeCustomer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(stripeCustomer.Id, source.Id); break; } } } - var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new PaymentMethodListOptions + var paymentMethods = stripeAdapter.ListPaymentMethodsAutoPagingAsync(new PaymentMethodListOptions { Customer = stripeCustomer.Id }); await foreach (var paymentMethod in paymentMethods) { - await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id); + await stripeAdapter.DetachPaymentMethodAsync(paymentMethod.Id); } } } @@ -575,7 +531,7 @@ public class SubscriberService( { case PaymentMethodType.BankAccount: { - var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.SetupIntentList(new SetupIntentListOptions + var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = token }); @@ -612,7 +568,7 @@ public class SubscriberService( await RemoveStripePaymentMethodsAsync(customer); // Attach the incoming payment method. - await stripeAdapter.PaymentMethodAttachAsync(token, + await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); var metadata = customer.Metadata; @@ -624,7 +580,7 @@ public class SubscriberService( } // Set the customer's default payment method in Stripe and remove their Braintree customer ID. - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -687,7 +643,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -705,7 +661,7 @@ public class SubscriberService( if (taxId != null) { - await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + await stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id); } if (!string.IsNullOrWhiteSpace(taxInformation.TaxId)) @@ -728,12 +684,12 @@ public class SubscriberService( try { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = taxIdType, Value = taxInformation.TaxId }); if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInformation.TaxId}" }); } } @@ -779,7 +735,7 @@ public class SubscriberService( Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); break; case @@ -787,14 +743,14 @@ public class SubscriberService( Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.None }); break; } if (!subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -814,7 +770,7 @@ public class SubscriberService( if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -823,57 +779,6 @@ public class SubscriberService( } } - public async Task VerifyBankAccount( - ISubscriber subscriber, - string descriptorCode) - { - var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id); - - if (string.IsNullOrEmpty(setupIntentId)) - { - logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id); - throw new BillingException(); - } - - try - { - await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, - new SetupIntentVerifyMicrodepositsOptions { DescriptorCode = descriptorCode }); - - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId); - - await stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId, - new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); - - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions - { - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = setupIntent.PaymentMethodId - } - }); - } - catch (StripeException stripeException) - { - if (!string.IsNullOrEmpty(stripeException.StripeError?.Code)) - { - var message = stripeException.StripeError.Code switch - { - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded => "You have exceeded the number of allowed verification attempts. Please contact support.", - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch => "The verification code you provided does not match the one sent to your bank account. Please try again.", - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationTimeout => "Your bank account was not verified within the required time period. Please contact support.", - _ => BillingException.DefaultMessage - }; - - throw new BadRequestException(message); - } - - logger.LogError(stripeException, "An unhandled Stripe exception was thrown while verifying subscriber's ({SubscriberID}) bank account", subscriber.Id); - throw new BillingException(); - } - } - public async Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber) { ArgumentNullException.ThrowIfNull(subscriber); @@ -884,7 +789,7 @@ public class SubscriberService( } try { - await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -903,7 +808,7 @@ public class SubscriberService( } try { - await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -922,7 +827,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdKey] = braintreeCustomerId; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -962,7 +867,7 @@ public class SubscriberService( return null; } - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); @@ -970,25 +875,6 @@ public class SubscriberService( return PaymentSource.From(setupIntent); } - private static TaxInformation GetTaxInformation( - Customer customer) - { - if (customer.Address == null) - { - return null; - } - - return new TaxInformation( - customer.Address.Country, - customer.Address.PostalCode, - customer.TaxIds?.FirstOrDefault()?.Value, - customer.TaxIds?.FirstOrDefault()?.Type, - customer.Address.Line1, - customer.Address.Line2, - customer.Address.City, - customer.Address.State); - } - private async Task RemoveBraintreeCustomerIdAsync( Customer customer) { @@ -999,7 +885,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdOldKey] = value; metadata[BraintreeCustomerIdKey] = null; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -1016,18 +902,18 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(customer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(customer.Id, source.Id); break; } } } - var paymentMethods = await stripeAdapter.CustomerListPaymentMethods(customer.Id); + var paymentMethods = await stripeAdapter.ListCustomerPaymentMethodsAsync(customer.Id); - await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.PaymentMethodDetachAsync(pm.Id))); + await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.DetachPaymentMethodAsync(pm.Id))); } private async Task ReplaceBraintreePaymentMethodAsync( diff --git a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs index ee60597601..7f7be9d1eb 100644 --- a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs +++ b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs @@ -7,7 +7,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Repositories; -using Bit.Core.Services; using OneOf.Types; using Stripe; @@ -53,7 +52,7 @@ public class RestartSubscriptionCommand( TrialPeriodDays = 0 }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(options); + var subscription = await stripeAdapter.CreateSubscriptionAsync(options); await EnableAsync(subscriber, subscription); return new None(); } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 2ee6b75664..ec5978988c 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -2,8 +2,8 @@ #nullable disable using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Services; using Stripe; namespace Bit.Core.Billing; @@ -22,7 +22,7 @@ public static class Utilities return null; } - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index d41548b5d8..6a88903a82 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -140,8 +140,10 @@ public static class FeatureFlagKeys public const string CreateDefaultLocation = "pm-19467-create-default-location"; public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; - public const string AccountRecoveryCommand = "pm-25581-prevent-provider-account-recovery"; - public const string PolicyValidatorsRefactor = "pm-26423-refactor-policy-side-effects"; + public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration"; + public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud"; + public const string BulkRevokeUsersV2 = "pm-28456-bulk-revoke-users-v2"; + public const string PremiumAccessQuery = "pm-21411-premium-access-query"; /* Architecture */ public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1"; @@ -155,13 +157,14 @@ public static class FeatureFlagKeys public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor"; public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; public const string Otp6Digits = "pm-18612-otp-6-digits"; - public const string FailedTwoFactorEmail = "pm-24425-send-2fa-failed-email"; public const string PM24579_PreventSsoOnExistingNonCompliantUsers = "pm-24579-prevent-sso-on-existing-non-compliant-users"; public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; public const string PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword = "pm-23174-manage-account-recovery-permission-drives-the-need-to-set-master-password"; - public const string RecoveryCodeSupportForSsoRequiredUsers = "pm-21153-recovery-code-support-for-sso-required"; public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; + public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; + public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; + public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; /* Autofill Team */ public const string IdpAutoSubmitLogin = "idp-auto-submit-login"; @@ -183,9 +186,6 @@ public static class FeatureFlagKeys /* Billing Team */ public const string TrialPayment = "PM-8163-trial-payment"; - public const string UsePricingService = "use-pricing-service"; - public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; - public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; public const string PM25379_UseNewOrganizationMetadataStructure = "pm-25379-use-new-organization-metadata-structure"; public const string PM24996ImplementUpgradeFromFreeDialog = "pm-24996-implement-upgrade-from-free-dialog"; @@ -195,17 +195,14 @@ public static class FeatureFlagKeys public const string PM26793_FetchPremiumPriceFromPricingService = "pm-26793-fetch-premium-price-from-pricing-service"; public const string PM23341_Milestone_2 = "pm-23341-milestone-2"; public const string PM26462_Milestone_3 = "pm-26462-milestone-3"; + public const string PM28265_EnableReconcileAdditionalStorageJob = "pm-28265-enable-reconcile-additional-storage-job"; + public const string PM28265_ReconcileAdditionalStorageJobEnableLiveMode = "pm-28265-reconcile-additional-storage-job-enable-live-mode"; /* Key Management Team */ - public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; - public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service"; public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration"; public const string Argon2Default = "argon2-default"; - public const string UserkeyRotationV2 = "userkey-rotation-v2"; public const string SSHKeyItemVaultItem = "ssh-key-vault-item"; - public const string UserSdkForDecryption = "use-sdk-for-decryption"; public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation"; - public const string PM17987_BlockType0 = "pm-17987-block-type-0"; public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings"; public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data"; public const string WindowsBiometricsV2 = "pm-25373-windows-biometrics-v2"; @@ -213,6 +210,8 @@ public static class FeatureFlagKeys public const string NoLogoutOnKdfChange = "pm-23995-no-logout-on-kdf-change"; public const string DisableType0Decryption = "pm-25174-disable-type-0-decryption"; public const string ConsolidatedSessionTimeoutComponent = "pm-26056-consolidated-session-timeout-component"; + public const string V2RegistrationTDEJIT = "pm-27279-v2-registration-tde-jit"; + public const string DataRecoveryTool = "pm-28813-data-recovery-tool"; /* Mobile Team */ public const string AndroidImportLoginsFlow = "import-logins-flow"; @@ -242,6 +241,7 @@ public static class FeatureFlagKeys public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators"; public const string UseChromiumImporter = "pm-23982-chromium-importer"; public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe"; + public const string SendUIRefresh = "pm-28175-send-ui-refresh"; /// /// Enable this flag to output email/OTP authenticated sends from the `GET sends` endpoint. When @@ -254,19 +254,17 @@ public static class FeatureFlagKeys public const string PM19051_ListEmailOtpSends = "tools-send-email-otp-listing"; /* Vault Team */ - public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge"; - public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; public const string CipherKeyEncryption = "cipher-key-encryption"; public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk"; - public const string EndUserNotifications = "pm-10609-end-user-notifications"; public const string PhishingDetection = "phishing-detection"; public const string RemoveCardItemTypePolicy = "pm-16442-remove-card-item-type-policy"; public const string PM22134SdkCipherListView = "pm-22134-sdk-cipher-list-view"; - public const string PM19315EndUserActivationMvp = "pm-19315-end-user-activation-mvp"; public const string PM22136_SdkCipherEncryption = "pm-22136-sdk-cipher-encryption"; public const string PM23904_RiskInsightsForPremium = "pm-23904-risk-insights-for-premium"; public const string PM25083_AutofillConfirmFromSearch = "pm-25083-autofill-confirm-from-search"; public const string VaultLoadingSkeletons = "pm-25081-vault-skeleton-loaders"; + public const string BrowserPremiumSpotlight = "pm-23384-browser-premium-spotlight"; + public const string MigrateMyVaultToMyItems = "pm-20558-migrate-myvault-to-myitems"; /* Innovation Team */ public const string ArchiveVaultItems = "pm-19148-innovation-archive"; @@ -276,6 +274,9 @@ public static class FeatureFlagKeys public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike"; public const string EventDiagnosticLogging = "pm-27666-siem-event-log-debugging"; + /* UIF Team */ + public const string RouterFocusManagement = "router-focus-management"; + public static List GetAllKeys() { return typeof(FeatureFlagKeys).GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index 5d9b5a1759..6067c60556 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -38,10 +38,6 @@ public class CurrentContext( public virtual List Providers { get; set; } public virtual Guid? InstallationId { get; set; } public virtual Guid? OrganizationId { get; set; } - public virtual bool CloudflareWorkerProxied { get; set; } - public virtual bool IsBot { get; set; } - public virtual bool MaybeBot { get; set; } - public virtual int? BotScore { get; set; } public virtual string ClientId { get; set; } public virtual Version ClientVersion { get; set; } public virtual bool ClientVersionIsPrerelease { get; set; } @@ -70,27 +66,6 @@ public class CurrentContext( DeviceType = dType; } - if (!BotScore.HasValue && httpContext.Request.Headers.TryGetValue("X-Cf-Bot-Score", out var cfBotScore) && - int.TryParse(cfBotScore, out var parsedBotScore)) - { - BotScore = parsedBotScore; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Worked-Proxied", out var cfWorkedProxied)) - { - CloudflareWorkerProxied = cfWorkedProxied == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Is-Bot", out var cfIsBot)) - { - IsBot = cfIsBot == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Maybe-Bot", out var cfMaybeBot)) - { - MaybeBot = cfMaybeBot == "1"; - } - if (httpContext.Request.Headers.TryGetValue("Bitwarden-Client-Version", out var bitWardenClientVersion) && Version.TryParse(bitWardenClientVersion, out var cVersion)) { ClientVersion = cVersion; diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index f62a048070..d527cdd363 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -31,9 +31,6 @@ public interface ICurrentContext Guid? InstallationId { get; set; } Guid? OrganizationId { get; set; } IdentityClientType IdentityClientType { get; set; } - bool IsBot { get; set; } - bool MaybeBot { get; set; } - int? BotScore { get; set; } string ClientId { get; set; } Version ClientVersion { get; set; } bool ClientVersionIsPrerelease { get; set; } diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 4901c5b43c..52c0a641ab 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -23,14 +23,14 @@ - - - + + + - - - + + + @@ -41,6 +41,7 @@ + @@ -50,24 +51,23 @@ - - - - - + - - - - + + + + + + + diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index fec9b80d8e..669e32bcbe 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -69,6 +69,11 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac /// The security state is a signed object attesting to the version of the user's account. /// public string? SecurityState { get; set; } + /// + /// Indicates whether the user has a personal premium subscription. + /// Does not include premium access from organizations - + /// do not use this to check whether the user can access premium features. + /// public bool Premium { get; set; } public DateTime? PremiumExpirationDate { get; set; } public DateTime? RenewalReminderDate { get; set; } @@ -200,11 +205,6 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac return Id; } - public bool GetPremium() - { - return Premium; - } - public int GetSecurityVersion() { // If no security version is set, it is version 1. The minimum initialized version is 2. diff --git a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs index 0e551c5d0e..abaf9406ba 100644 --- a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs +++ b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs @@ -26,5 +26,6 @@ public static class KeyManagementServiceCollectionExtensions private static void AddKeyManagementQueries(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); } } diff --git a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs similarity index 92% rename from src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs index b64e826911..bdf538e6d8 100644 --- a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs @@ -1,8 +1,7 @@ -using Bit.Core.KeyManagement.Models.Api.Request; -using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; public class AccountKeysRequestModel { diff --git a/src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs similarity index 91% rename from src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs index 24c1e6a946..f9b009f7e2 100644 --- a/src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs @@ -1,7 +1,7 @@ using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; public class PublicKeyEncryptionKeyPairRequestModel { diff --git a/src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs similarity index 93% rename from src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs index 3cdb4f53f1..a569bc70ab 100644 --- a/src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs @@ -1,7 +1,7 @@ using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; public class SignatureKeyPairRequestModel { diff --git a/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs b/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs new file mode 100644 index 0000000000..3821831bad --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.KeyManagement.Models.Data; + +public class KeyConnectorConfirmationDetails +{ + public required string OrganizationName { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs index cabdca59ea..3d552a10de 100644 --- a/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs +++ b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs @@ -1,9 +1,34 @@ namespace Bit.Core.KeyManagement.Models.Data; - +/// +/// Represents an expanded account cryptographic state for a user. Expanded here means +/// that it does not only contain the (wrapped) private / signing key, but also the public +/// key / verifying key. The client side only needs a subset of this data to unlock +/// their vault and the public parts can be derived. +/// public class UserAccountKeysData { public required PublicKeyEncryptionKeyPairData PublicKeyEncryptionKeyPairData { get; set; } public SignatureKeyPairData? SignatureKeyPairData { get; set; } public SecurityStateData? SecurityStateData { get; set; } + + /// + /// Checks whether the account cryptographic state is for a V1 encryption user or a V2 encryption user. + /// Throws if the state is invalid + /// + public bool IsV2Encryption() + { + if (PublicKeyEncryptionKeyPairData.SignedPublicKey != null && SignatureKeyPairData != null && SecurityStateData != null) + { + return true; + } + else if (PublicKeyEncryptionKeyPairData.SignedPublicKey == null && SignatureKeyPairData == null && SecurityStateData == null) + { + return false; + } + else + { + throw new InvalidOperationException("Invalid account cryptographic state: V2 encryption fields must be either all present or all absent."); + } + } } diff --git a/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs b/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs new file mode 100644 index 0000000000..60b78c03f4 --- /dev/null +++ b/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs @@ -0,0 +1,8 @@ +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Queries.Interfaces; + +public interface IKeyConnectorConfirmationDetailsQuery +{ + public Task Run(string orgSsoIdentifier, Guid userId); +} diff --git a/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs b/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs new file mode 100644 index 0000000000..0c210e2fd1 --- /dev/null +++ b/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs @@ -0,0 +1,35 @@ +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Repositories; + +namespace Bit.Core.KeyManagement.Queries; + +public class KeyConnectorConfirmationDetailsQuery : IKeyConnectorConfirmationDetailsQuery +{ + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + + public KeyConnectorConfirmationDetailsQuery(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository) + { + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + } + + public async Task Run(string orgSsoIdentifier, Guid userId) + { + var org = await _organizationRepository.GetByIdentifierAsync(orgSsoIdentifier); + if (org is not { UseKeyConnector: true }) + { + throw new NotFoundException(); + } + + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, userId); + if (orgUser == null) + { + throw new NotFoundException(); + } + + return new KeyConnectorConfirmationDetails { OrganizationName = org.Name, }; + } +} diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs index fad0af840d..f9cc04f73e 100644 --- a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs @@ -29,8 +29,8 @@ .mj-outlook-group-fix { width:100% !important; } - - + + - - - - + + + + - + - + - - + +
    - + - - + +
    - +
    - +
    - - + + - - + +
    - +
    - +
    - + - + - + - +
    - +
    - + - +
    - +
    - +

    Verify your email to access this Bitwarden Send

    - +
    - +
    - + - +
    - + - + - - +
    - + +
    - + - +
    - +
    - +
    - +
    - +
    - - + + - - + +
    - +
    - +
    - - + + - + - + - - + +
    - +
    - - + +
    - +
    - +
    - +
    - + - + - + - + - + - +
    - +
    Your verification code is:
    - +
    - +
    {{Token}}
    - +
    - +
    - +
    - -
    This code expires in {{Expiry}} minutes. After that, you'll need to - verify your email again.
    - + +
    This code expires in {{Expiry}} minutes. After that, you'll need + to verify your email again.
    +
    - +
    - +
    - +
    - +
    - - + + - - + +
    - +
    - +
    - +
    - + - + - +
    - +

    Bitwarden Send transmits sensitive, temporary information to others easily and securely. Learn more about @@ -325,160 +333,160 @@ sign up to try it today.

    - +
    - +
    - +
    - +
    - +
    - - + +
    - +
    - - + + - + - + - - + +
    - +
    - - + +
    - +
    - +
    - + - + - +
    - +

    - Learn more about Bitwarden -

    - Find user guides, product documentation, and videos on the - Bitwarden Help Center.
    - + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center. +
    - +
    - + - +
    - + - + - - +
    - +
    - +
    - +
    - +
    - - + +
    - +
    - - + + - + - + - - + +
    - +
    - +
    - + - + - + - +
    - - + + - + - + - +
    @@ -493,15 +501,15 @@
    - + - + - +
    @@ -516,15 +524,15 @@
    - + - + - +
    @@ -539,15 +547,15 @@
    - + - + - +
    @@ -562,15 +570,15 @@
    - + - + - +
    @@ -585,15 +593,15 @@
    - + - + - +
    @@ -608,15 +616,15 @@
    - + - + - +
    @@ -631,20 +639,20 @@
    - - + +
    - +

    © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA @@ -655,28 +663,29 @@ bitwarden.com | Learn why we include this

    - +
    - +
    - +
    - +
    - - + + - - + +
    - + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs new file mode 100644 index 0000000000..65e37e87dd --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs @@ -0,0 +1,815 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + You can now share passwords with members of {{OrganizationName}}! +

    + +
    + + + + + + + +
    + + Log in + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    As a member of {{OrganizationName}}:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Organization Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    Your account is owned by {{OrganizationName}} and is subject to their security and management policies.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Group Users Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + +
    You can easily access and share passwords with your team.
    + +
    + + + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs new file mode 100644 index 0000000000..38c45f2dd1 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs @@ -0,0 +1,4 @@ +{{#>TitleContactUsTextLayout}} + You may now access logins and other items {{OrganizationName}} has shared with you from your Bitwarden vault. + Tip: Use the Bitwarden mobile app to quickly save logins and auto-fill forms. Download from the App Store or Google Play. +{{/TitleContactUsTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs new file mode 100644 index 0000000000..c22bc80a51 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs @@ -0,0 +1,983 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + You can now share passwords with members of {{OrganizationName}}! +

    + +
    + + + + + + + +
    + + Log in + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    As a member of {{OrganizationName}}:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Collections Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    You can access passwords {{OrganizationName}} has shared with you.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Group Users Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + +
    You can easily share passwords with friends, family, or coworkers.
    + +
    + + + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + +
    Download Bitwarden on all devices
    + +
    + +
    Already using the browser extension? + Download the Bitwarden mobile app from the + App Store + or Google Play + to quickly save logins and autofill forms on the go.
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + + Download on the App Store + + + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + + Get it on Google Play + + + +
    + +
    + +
    + + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs new file mode 100644 index 0000000000..38c45f2dd1 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs @@ -0,0 +1,4 @@ +{{#>TitleContactUsTextLayout}} + You may now access logins and other items {{OrganizationName}} has shared with you from your Bitwarden vault. + Tip: Use the Bitwarden mobile app to quickly save logins and auto-fill forms. Download from the App Store or Google Play. +{{/TitleContactUsTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs new file mode 100644 index 0000000000..9c4b2406d4 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs @@ -0,0 +1,920 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    An administrator from {{OrganizationName}} will approve you + before you can share passwords. While you wait for approval, get + started with Bitwarden Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Devices Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Take your passwords with you anywhere.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs new file mode 100644 index 0000000000..38f53e7755 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs @@ -0,0 +1,19 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +A {{OrganizationName}} administrator will approve you before you can share passwords. +While you wait for approval, get started with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Download Bitwarden on all devices: +Take your passwords with you anywhere. (https://www.bitwarden.com/download) + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs new file mode 100644 index 0000000000..d0a4e7e0a4 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs @@ -0,0 +1,919 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    Follow these simple steps to get up and running with Bitwarden + Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Devices Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Take your passwords with you anywhere.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs new file mode 100644 index 0000000000..f698e79ca7 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs @@ -0,0 +1,18 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +Follow these simple steps to get up and running with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Download Bitwarden on all devices: +Take your passwords with you anywhere. (https://bitwarden.com/help/auto-fill-browser/) + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs new file mode 100644 index 0000000000..439fed4b0a --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs @@ -0,0 +1,920 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    An administrator from {{OrganizationName}} will need to confirm + you before you can share passwords. Get started with Bitwarden + Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Autofill Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + + + + + +
    + + + +
    + +
    Fill your passwords securely with one click.
    + +
    + +
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs new file mode 100644 index 0000000000..3808cc818d --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs @@ -0,0 +1,20 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +A {{OrganizationName}} administrator will approve you before you can share passwords. +Get started with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Try Bitwarden autofill: +Fill your passwords securely with one click. (https://bitwarden.com/help/auto-fill-browser/) + + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Mjml/.mjmlconfig b/src/Core/MailTemplates/Mjml/.mjmlconfig index 92734a5f71..a71e3b5ee9 100644 --- a/src/Core/MailTemplates/Mjml/.mjmlconfig +++ b/src/Core/MailTemplates/Mjml/.mjmlconfig @@ -1,7 +1,9 @@ { "packages": [ "components/mj-bw-hero", + "components/mj-bw-simple-hero", "components/mj-bw-icon-row", - "components/mj-bw-learn-more-footer" + "components/mj-bw-learn-more-footer", + "emails/AdminConsole/components/mj-bw-inviter-info" ] } diff --git a/src/Core/MailTemplates/Mjml/README.md b/src/Core/MailTemplates/Mjml/README.md index b9041c94f6..fabb393ee0 100644 --- a/src/Core/MailTemplates/Mjml/README.md +++ b/src/Core/MailTemplates/Mjml/README.md @@ -1,16 +1,15 @@ -# MJML email templating +# `MJML` email templating -This directory contains MJML templates for emails. MJML is a markup language designed to reduce the pain of coding responsive email templates. Component based development features in MJML improve code quality and reusability. +This directory contains `MJML` templates for emails. `MJML` is a markup language designed to reduce the pain of coding responsive email templates. Component-based development features in `MJML` improve code quality and reusability. -MJML stands for MailJet Markup Language. +> [!TIP] +> `MJML` stands for MailJet Markup Language. ## Implementation considerations -These `MJML` templates are compiled into HTML which will then be further consumed by our Handlebars mail service. We can continue to use this service to assign values from our View Models. This leverages the existing infrastructure. It also means we can continue to use the double brace (`{{}}`) syntax within MJML since Handlebars can be used to assign values to those `{{variables}}`. +`MJML` templates are compiled into `HTML`, and those outputs are then consumed by Handlebars to render the final email for delivery. It builds on top of our existing infrastructure and means we can continue to use the double brace (`{{}}`) syntax within `MJML`, since Handlebars will assign values to those `{{variables}}`. -There is no change on how we interact with our view models. - -There is an added step where we compile `*.mjml` to `*.html.hbs`. `*.html.hbs` is the format we use so the handlebars service can apply the variables. This build pipeline process is in progress and may need to be manually done at times. +To do this, there is an added step where we compile `*.mjml` to `*.html.hbs`. `*.html.hbs` is the format we use so the Handlebars service can apply the variables. This build pipeline process is in progress and may need to be manually done at times. ### `*.txt.hbs` @@ -37,45 +36,50 @@ npm run build:minify npm run prettier ``` -## Development +## Development process -MJML supports components and you can create your own components by adding them to `.mjmlconfig`. Components are simple JavaScript that return MJML markup based on the attributes assigned, see components/mj-bw-hero.js. The markup is not a proper object, but contained in a string. +`MJML` supports components and you can create your own components by adding them to `.mjmlconfig`. Components are simple JavaScript that return `MJML` markup based on the attributes assigned, see components/mj-bw-hero.js. The markup is not a proper object, but contained in a string. -When using MJML templating you can use the above [commands](#building-mjml-files) to compile the template and view it in a web browser. +When using `MJML` templating you can use the above [commands](#building-mjml-files) to compile the template and view it in a web browser. -Not all MJML tags have the same attributes, it is highly recommended to review the documentation on the official MJML website to understand the usages of each of the tags. +Not all `MJML` tags have the same attributes, it is highly recommended to review the documentation on the official MJML website to understand the usages of each of the tags. -### Recommended development - IMailService +### Developing the mail template -#### Mjml email template development +1. Create `cool-email.mjml` in appropriate team directory. +2. Run `npm run build:watch`. +3. View compiled `HTML` output in a web browser. +4. Iterate through your development. While running `build:watch` you should be able to refresh the browser page after the `mjml/js` recompile to see the changes. -1. create `cool-email.mjml` in appropriate team directory -2. run `npm run build:watch` -3. view compiled `HTML` output in a web browser -4. iterate -> while `build:watch`'ing you should be able to refresh the browser page after the mjml/js re-compile to see the changes +### Testing the mail template with `IMailer` -#### Testing with `IMailService` +After the email is developed in the [initial step](#developing-the-mail-template), we need to make sure that the email `{{variables}}` are populated properly by Handlebars. We can do this by running it through an `IMailer` implementation. The `IMailer`, documented [here](../../Platform/Mail/README.md#step-3-create-handlebars-templates), requires that the ViewModel, the `.html.hbs` `MJML` build artifact, and `.text.hbs` files be in the same directory. -After the email is developed from the [initial step](#mjml-email-template-development) make sure the email `{{variables}}` are populated properly by running it through an `IMailService` implementation. +1. Run `npm run build:hbs`. +2. Copy built `*.html.hbs` files from the build directory to the directory that the `IMailer` expects. All files in the `Core/MailTemplates/Mjml/out` directory should be copied to the `/src/Core/MailTemplates/Mjml` directory, ensuring that the files are in the same directory as the corresponding ViewModels. If a shared component is modified it is important to copy and overwrite all files in that directory to capture changes in the `*.html.hbs` files. +3. Run code that will send the email. -1. run `npm run build:hbs` -2. copy built `*.html.hbs` files from the build directory to a location the mail service can consume them - 1. all files in the `Core/MailTemplates/Mjml/out` directory can be copied to the `src/Core/MailTemplates/Handlebars/MJML` directory. If a shared component is modified it is important to copy and overwrite all files in that directory to capture - changes in the `*.html.hbs`. -3. run code that will send the email +The minified `html.hbs` artifacts are deliverables and must be placed into the correct `/src/Core/MailTemplates/Mjml` directories in order to be used by `IMailer` implementations, see step 2 above. + +### Testing the mail template with `IMailService` + +> [!WARNING] +> The `IMailService` has been deprecated. The [IMailer](#testing-the-mail-template-with-imailer) should be used instead. + +After the email is developed from the [initial step](#developing-the-mail-template), make sure the email `{{variables}}` are populated properly by running it through an `IMailService` implementation. + +1. Run `npm run build:hbs` +2. Copy built `*.html.hbs` files from the build directory to a location the mail service can consume them. + 1. All files in the `Core/MailTemplates/Mjml/out` directory should be copied to the `src/Core/MailTemplates/Handlebars/MJML` directory. If a shared component is modified it is important to copy and overwrite all files in that directory to capture changes in the `*.html.hbs`. +3. Run code that will send the email. The minified `html.hbs` artifacts are deliverables and must be placed into the correct `src/Core/MailTemplates/Handlebars/` directories in order to be used by `IMailService` implementations, see 2.1 above. -### Recommended development - IMailer - -TBD - PM-26475 - ### Custom tags There is currently a `mj-bw-hero` tag you can use within your `*.mjml` templates. This is a good example of how to create a component that takes in attribute values allowing us to be more DRY in our development of emails. Since the attribute's input is a string we are able to define whatever we need into the component, in this case `mj-bw-hero`. -In order to view the custom component you have written you will need to include it in the `.mjmlconfig` and reference it in an `mjml` template file. - +In order to view the custom component you have written you will need to include it in the `.mjmlconfig` and reference it in a `.mjml` template file. ```html ``` -Attributes in Custom Components are defined by the developer. They can be required or optional depending on implementation. See the official MJML documentation for more information. - +Attributes in custom components are defined by the developer. They can be required or optional depending on implementation. See the official `MJML` [documentation](https://documentation.mjml.io/#components) for more information. ```js static allowedAttributes = { "img-src": "string", // REQUIRED: Source for the image displayed in the right-hand side of the blue header area @@ -108,7 +111,7 @@ Custom components, such as `mj-bw-hero`, must be defined in the `.mjmlconfig` in ### `mj-include` -You are also able to reference other more static MJML templates in your MJML file simply by referencing the file within the MJML template. +You are also able to reference other more static `MJML` templates in your `MJML` file simply by referencing the file within the `MJML` template. ```html @@ -118,6 +121,6 @@ You are also able to reference other more static MJML templates in your MJML fil ``` #### `head.mjml` -Currently we include the `head.mjml` file in all MJML templates as it contains shared styling and formatting that ensures consistency across all email implementations. +Currently we include the `head.mjml` file in all `MJML` templates as it contains shared styling and formatting that ensures consistency across all email implementations. In the future we may deviate from this practice to support different layouts. At that time we will modify the docs with direction. diff --git a/src/Core/MailTemplates/Mjml/build.js b/src/Core/MailTemplates/Mjml/build.js index db8a7fe433..4e3eaef449 100644 --- a/src/Core/MailTemplates/Mjml/build.js +++ b/src/Core/MailTemplates/Mjml/build.js @@ -41,8 +41,10 @@ if (!fs.existsSync(config.outputDir)) { } } -// Find all MJML files with absolute path -const mjmlFiles = glob.sync(`${config.inputDir}/**/*.mjml`); +// Find all MJML files with absolute paths, excluding components directories +const mjmlFiles = glob.sync(`${config.inputDir}/**/*.mjml`, { + ignore: ['**/components/**'] +}); console.log(`\n[INFO] Found ${mjmlFiles.length} MJML file(s) to compile...`); diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js index f7f402c96e..d0ccde5513 100644 --- a/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js @@ -1,4 +1,12 @@ const { BodyComponent } = require("mjml-core"); + +const BODY_TEXT_STYLES = ` + font-family="Roboto, 'Helvetica Neue', Helvetica, Arial, sans-serif" + font-size="16px" + font-weight="400" + line-height="24px" +`; + class MjBwIconRow extends BodyComponent { static dependencies = { "mj-column": ["mj-bw-icon-row"], @@ -18,16 +26,16 @@ class MjBwIconRow extends BodyComponent { static defaultAttributes = {}; - componentHeadStyle = (breakpoint) => { + headStyle = (breakpoint) => { return ` - @media only screen and (max-width:${breakpoint}): { - ".mj-bw-icon-row-text": { - padding-left: "5px !important", - line-height: "20px", - }, - ".mj-bw-icon-row": { - padding: "10px 15px", - width: "fit-content !important", + @media only screen and (max-width:${breakpoint}) { + .mj-bw-icon-row-text { + padding-left: 5px !important; + line-height: 20px; + } + .mj-bw-icon-row { + padding: 10px 15px; + width: fit-content !important; } } `; @@ -36,30 +44,35 @@ class MjBwIconRow extends BodyComponent { render() { const headAnchorElement = this.getAttribute("head-url-text") && this.getAttribute("head-url") - ? ` - ${this.getAttribute("head-url-text")} - - External Link Icon - - ` + ? ` + + + ${this.getAttribute("head-url-text")} + + External Link Icon + + + ` : ""; const footAnchorElement = this.getAttribute("foot-url-text") && this.getAttribute("foot-url") - ? ` - ${this.getAttribute("foot-url-text")} - - External Link Icon - - ` + ? ` + + ${this.getAttribute("foot-url-text")} + + External Link Icon + + + ` : ""; return this.renderMJML( @@ -76,19 +89,11 @@ class MjBwIconRow extends BodyComponent { /> - - ` + - headAnchorElement + - ` - - + ${headAnchorElement} + ${this.getAttribute("text")} - - ` + - footAnchorElement + - ` - + ${footAnchorElement} diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js b/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js new file mode 100644 index 0000000000..e7364e34b0 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js @@ -0,0 +1,40 @@ +const { BodyComponent } = require("mjml-core"); + +class MjBwSimpleHero extends BodyComponent { + static dependencies = { + // Tell the validator which tags are allowed as our component's parent + "mj-column": ["mj-bw-simple-hero"], + "mj-wrapper": ["mj-bw-simple-hero"], + // Tell the validator which tags are allowed as our component's children + "mj-bw-simple-hero": [], + }; + + static allowedAttributes = {}; + + static defaultAttributes = {}; + + render() { + return this.renderMJML( + ` + + + + + + `, + ); + } +} + +module.exports = MjBwSimpleHero; diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml new file mode 100644 index 0000000000..24f85af31c --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + As a member of {{OrganizationName}}: + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml new file mode 100644 index 0000000000..2e48e82f84 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + As a member of {{OrganizationName}}: + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js new file mode 100644 index 0000000000..e9d392f570 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js @@ -0,0 +1,35 @@ +const { BodyComponent } = require("mjml-core"); + +class MjBwInviterInfo extends BodyComponent { + + static dependencies = { + "mj-column": ["mj-bw-inviter-info"], + "mj-wrapper": ["mj-bw-inviter-info"], + "mj-bw-inviter-info": [], + }; + + static allowedAttributes = { + "expiration-date": "string", // REQUIRED: Date to display + "email-address": "string", // Optional: Email address to display + }; + + render() { + const emailAddressText = this.getAttribute("email-address") + ? `This invitation was sent by ${this.getAttribute("email-address")} and expires ` + : "This invitation expires "; + + return this.renderMJML( + ` + + + + ${emailAddressText + this.getAttribute("expiration-date")} + + + + ` + ); + } +} + +module.exports = MjBwInviterInfo; diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml new file mode 100644 index 0000000000..8e990dc924 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml @@ -0,0 +1,38 @@ + + + + + Download Bitwarden on all devices + + + Already using the browser extension? + Download the Bitwarden mobile app from the + App Store + or Google Play + to quickly save logins and autofill forms on the go. + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml index 86de49016d..7c81a700f2 100644 --- a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml @@ -9,7 +9,7 @@ diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml similarity index 97% rename from src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml rename to src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml index e071cd26cc..4fc9bc466a 100644 --- a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml @@ -9,7 +9,7 @@ diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml index 39f18fce66..7b8a03dc7e 100644 --- a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml @@ -9,7 +9,7 @@ diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml index d3d4eb9891..660bbf0b45 100644 --- a/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml @@ -1,7 +1,13 @@ - + + .send-bubble { + padding-left: 20px; + padding-right: 20px; + width: 90% !important; + } + @@ -18,18 +24,17 @@ Your verification code is: - {{Token}} + + {{Token}} + - This code expires in {{Expiry}} minutes. After that, you'll need to - verify your email again. + This code expires in {{Expiry}} minutes. After that, you'll need + to verify your email again. - + + + + + + + + + + + + + + + + + Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually + at {{BaseAnnualRenewalPrice}} + tax. + + + As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + + + Questions? Contact + support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml new file mode 100644 index 0000000000..dcf193875a --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + + + Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually. + + + Questions? Contact support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml new file mode 100644 index 0000000000..a460442a7c --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. + + + As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + + + Questions? Contact + support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml b/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml deleted file mode 100644 index c50a5d1292..0000000000 --- a/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml +++ /dev/null @@ -1,27 +0,0 @@ - - - - - - - - - - - - - Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc. - - - - - - - - - diff --git a/src/Core/MailTemplates/README.md b/src/Core/MailTemplates/README.md index bd42b2a10f..312821afd3 100644 --- a/src/Core/MailTemplates/README.md +++ b/src/Core/MailTemplates/README.md @@ -75,4 +75,14 @@ The `IMailService` automatically uses both versions when sending emails: - Test plain text templates to ensure they're readable and convey the same message ## `*.mjml` -This is a templating language we use to increase efficiency when creating email content. See the readme within the `./mjml` directory for more comprehensive information. +This is a templating language we use to increase efficiency when creating email content. See the `MJML` [documentation](./Mjml/README.md) for more details. + +# Managing email assets + +We host assets that are included in emails at `assets.bitwarden.com`, at the `/email/v1` path. This corresponds to a static file storage container that is managed by our SRE team. For example: https://assets.bitwarden.com/email/v1/mail-github.png. This is the URL for all assets for emails sent from any environment. + +## Adding an asset + +When you are creating an email that needs a new asset, you should first check to see if that asset already exists. The easiest way to do this is check at the corresponding `https://assets.bitwarden.com/email/v1/` URL (e.g. https://assets.bitwarden.com/email/v1/my_new_image.png). + +If the asset you are adding is not there, enter a ticket for the SRE team to add the asset to the email asset container. The preferred format for assets is a `.png` file, and the file(s) should be attached to the ticket. \ No newline at end of file diff --git a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs index 7473738ffc..aa49c25d36 100644 --- a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs +++ b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs @@ -299,7 +299,7 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate ? organization.SmServiceAccounts - plan.SecretsManager.BaseServiceAccount : 0, PurchasedAdditionalStorage = organization.MaxStorageGb.HasValue - ? organization.MaxStorageGb.Value - (plan.PasswordManager.BaseStorageGb ?? 0) : + ? organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb : 0 }; } diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index be514cb39f..68a060b4a8 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -1,4 +1,5 @@ using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Stripe; #nullable enable @@ -150,7 +151,7 @@ public class SubscriptionInfo } Quantity = (int)item.Quantity; - SponsoredSubscriptionItem = item.Plan != null && Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); + SponsoredSubscriptionItem = item.Plan != null && SponsoredPlans.All.Any(p => p.StripePlanId == item.Plan.Id); } public bool AddonSubscriptionItem { get; set; } diff --git a/src/Core/Models/Data/UserKdfInformation.cs b/src/Core/Models/Data/UserKdfInformation.cs index 14f525bb82..0e5696e581 100644 --- a/src/Core/Models/Data/UserKdfInformation.cs +++ b/src/Core/Models/Data/UserKdfInformation.cs @@ -4,8 +4,8 @@ namespace Bit.Core.Models.Data; public class UserKdfInformation { - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } + public required KdfType Kdf { get; set; } + public required int KdfIterations { get; set; } public int? KdfMemory { get; set; } public int? KdfParallelism { get; set; } } diff --git a/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs b/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs new file mode 100644 index 0000000000..b852d24ec9 --- /dev/null +++ b/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Models.Mail.Auth; + +public class OrganizationWelcomeEmailViewModel : BaseMailModel +{ + public required string OrganizationName { get; set; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs new file mode 100644 index 0000000000..e3aff02f5d --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs @@ -0,0 +1,16 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; + +public class Families2019RenewalMailView : BaseMailView +{ + public required string BaseMonthlyRenewalPrice { get; set; } + public required string BaseAnnualRenewalPrice { get; set; } + public required string DiscountedAnnualRenewalPrice { get; set; } + public required string DiscountAmount { get; set; } +} + +public class Families2019RenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Families renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs new file mode 100644 index 0000000000..227613999b --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -0,0 +1,584 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + +
    Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually + at {{BaseAnnualRenewalPrice}} + tax.
    + +
    + +
    As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
    + +
    + +
    Questions? Contact + support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs new file mode 100644 index 0000000000..88d64f9acf --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -0,0 +1,7 @@ +Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually +at {{BaseAnnualRenewalPrice}} + tax. + +As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs new file mode 100644 index 0000000000..eb7bef4322 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs @@ -0,0 +1,13 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; + +public class Families2020RenewalMailView : BaseMailView +{ + public required string MonthlyRenewalPrice { get; set; } +} + +public class Families2020RenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Families renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs new file mode 100644 index 0000000000..ac6b80993c --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs @@ -0,0 +1,619 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Your Bitwarden Families renewal is updating +

    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + +
    Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    Questions? Contact support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs new file mode 100644 index 0000000000..002a48cf10 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs @@ -0,0 +1,3 @@ +Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs new file mode 100644 index 0000000000..e231a44467 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs @@ -0,0 +1,15 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Premium; + +public class PremiumRenewalMailView : BaseMailView +{ + public required string BaseMonthlyRenewalPrice { get; set; } + public required string DiscountedMonthlyRenewalPrice { get; set; } + public required string DiscountAmount { get; set; } +} + +public class PremiumRenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Premium renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs new file mode 100644 index 0000000000..a6b2fda0f7 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -0,0 +1,583 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + +
    Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    Questions? Contact + support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs new file mode 100644 index 0000000000..41300d0f96 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -0,0 +1,6 @@ +Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. + +As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs deleted file mode 100644 index aeca436dbb..0000000000 --- a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Bit.Core.Platform.Mail.Mailer; - -namespace Bit.Core.Models.Mail.UpdatedInvoiceIncoming; - -public class UpdatedInvoiceUpcomingView : BaseMailView; - -public class UpdatedInvoiceUpcomingMail : BaseMail -{ - public override string Subject { get => "Your Subscription Will Renew Soon"; } -} diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs deleted file mode 100644 index a044171fe5..0000000000 --- a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs +++ /dev/null @@ -1,30 +0,0 @@ -
    Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc.

    © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

    Always confirm you are on a trusted Bitwarden domain before logging in:
    bitwarden.com | Learn why we include this

    \ No newline at end of file diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs deleted file mode 100644 index a2db92bac2..0000000000 --- a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs +++ /dev/null @@ -1,3 +0,0 @@ -{{#>BasicTextLayout}} - Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc. -{{/BasicTextLayout}} diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index a622b98e05..ec39c495aa 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -1,4 +1,5 @@ -using Bit.Core.Enums; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; using Bit.Core.NotificationCenter.Enums; namespace Bit.Core.Models; @@ -103,3 +104,9 @@ public class LogOutPushNotification public Guid UserId { get; set; } public PushNotificationLogOutReason? Reason { get; set; } } + +public class SyncPolicyPushNotification +{ + public Guid OrganizationId { get; set; } + public required Policy Policy { get; set; } +} diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index 8cfd0a8df1..b502cc6e4e 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -12,8 +12,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Authorization; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -43,6 +45,9 @@ using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using V1_RevokeUsersCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; +using V2_RevokeUsersCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + namespace Bit.Core.OrganizationFeatures; public static class OrganizationServiceCollectionExtensions @@ -86,6 +91,7 @@ public static class OrganizationServiceCollectionExtensions private static void AddOrganizationUpdateCommands(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); } private static void AddOrganizationEnableCommands(this IServiceCollection services) => @@ -130,14 +136,20 @@ public static class OrganizationServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); - services.AddScoped(); services.AddScoped(); services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); + + services.AddScoped(); + + services.AddScoped(); + services.AddScoped(); } private static void AddOrganizationApiKeyCommandsQueries(this IServiceCollection services) @@ -192,6 +204,7 @@ public static class OrganizationServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs index 2756f8930b..566c723692 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -7,7 +8,6 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; @@ -54,10 +54,9 @@ public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand foreach (var selfHostedSponsorship in sponsorshipsData) { - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductTierType; + var requiredSponsoringProductType = SponsoredPlans.Get(selfHostedSponsorship.PlanSponsorshipType).SponsoringProductTierType; var sponsoringOrgProductTier = sponsoringOrg.PlanType.GetProductTier(); - if (requiredSponsoringProductType == null - || sponsoringOrgProductTier != requiredSponsoringProductType.Value) + if (sponsoringOrgProductTier != requiredSponsoringProductType) { continue; // prevent unsupported sponsorships } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs index a54106481c..6d60f05b2a 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; @@ -13,9 +13,9 @@ public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand { private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; - public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) + public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IStripePaymentService paymentService) { _organizationSponsorshipRepository = organizationSponsorshipRepository; _organizationRepository = organizationRepository; @@ -50,11 +50,10 @@ public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand } // Check org to sponsor's product type - var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductTierType; + var requiredSponsoredProductType = SponsoredPlans.Get(sponsorship.PlanSponsorshipType.Value).SponsoredProductTierType; var sponsoredOrganizationProductTier = sponsoredOrganization.PlanType.GetProductTier(); - if (requiredSponsoredProductType == null || - sponsoredOrganizationProductTier != requiredSponsoredProductType.Value) + if (sponsoredOrganizationProductTier != requiredSponsoredProductType) { throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs index dcda77acea..4b983317c9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs @@ -3,6 +3,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -13,14 +15,14 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnte public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; public ValidateSponsorshipCommand( IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) { @@ -95,7 +97,7 @@ public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSpo return false; } - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); + var sponsoredPlan = SponsoredPlans.Get(existingSponsorship.PlanSponsorshipType.Value); var sponsoringOrganization = await _organizationRepository .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index a729937fad..ab4b17d215 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -7,7 +8,6 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; @@ -34,11 +34,10 @@ public class CreateSponsorshipCommand( throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); } - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductTierType; + var requiredSponsoringProductType = SponsoredPlans.Get(sponsorshipType).SponsoringProductTierType; var sponsoringOrgProductTier = sponsoringOrganization.PlanType.GetProductTier(); - if (requiredSponsoringProductType == null || - sponsoringOrgProductTier != requiredSponsoringProductType.Value) + if (sponsoringOrgProductTier != requiredSponsoringProductType) { throw new BadRequestException("Specified Organization cannot sponsor other organizations."); } diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs index a0ce7c03b9..25b84fe989 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; @@ -12,13 +13,13 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class AddSecretsManagerSubscriptionCommand : IAddSecretsManagerSubscriptionCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationService _organizationService; private readonly IProviderRepository _providerRepository; private readonly IPricingClient _pricingClient; public AddSecretsManagerSubscriptionCommand( - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationService organizationService, IProviderRepository providerRepository, IPricingClient pricingClient) diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs index d4e1b3cd8d..baf2616a53 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -18,7 +19,7 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubscriptionCommand { private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; private readonly IServiceAccountRepository _serviceAccountRepository; @@ -29,7 +30,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs public UpdateSecretsManagerSubscriptionCommand( IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger, IServiceAccountRepository serviceAccountRepository, diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs index 2b39e6cca6..092ee0f46e 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs @@ -11,6 +11,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -26,7 +27,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IOrganizationConnectionRepository _organizationConnectionRepository; @@ -41,7 +42,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand IOrganizationUserRepository organizationUserRepository, ICollectionRepository collectionRepository, IGroupRepository groupRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, ISsoConfigRepository ssoConfigRepository, IOrganizationConnectionRepository organizationConnectionRepository, @@ -254,9 +255,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand organization.UseApi = newPlan.HasApi; organization.SelfHost = newPlan.HasSelfHost; organization.UsePolicies = newPlan.HasPolicies; - organization.MaxStorageGb = !newPlan.PasswordManager.BaseStorageGb.HasValue - ? (short?)null - : (short)(newPlan.PasswordManager.BaseStorageGb.Value + upgrade.AdditionalStorageGb); + organization.MaxStorageGb = (short)(newPlan.PasswordManager.BaseStorageGb + upgrade.AdditionalStorageGb); organization.UseGroups = newPlan.HasGroups; organization.UseDirectory = newPlan.HasDirectory; organization.UseEvents = newPlan.HasEvents; diff --git a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs b/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs deleted file mode 100644 index 6b76bc35f0..0000000000 --- a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs +++ /dev/null @@ -1,95 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text; -using Azure.Storage.Blobs; -using Azure.Storage.Blobs.Models; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -public class AzurePhishingDomainStorageService -{ - private const string _containerName = "phishingdomains"; - private const string _domainsFileName = "domains.txt"; - private const string _checksumFileName = "checksum.txt"; - - private readonly BlobServiceClient _blobServiceClient; - private readonly ILogger _logger; - private BlobContainerClient _containerClient; - - public AzurePhishingDomainStorageService( - GlobalSettings globalSettings, - ILogger logger) - { - _blobServiceClient = new BlobServiceClient(globalSettings.Storage.ConnectionString); - _logger = logger; - } - - public async Task> GetDomainsAsync() - { - await InitAsync(); - - var blobClient = _containerClient.GetBlobClient(_domainsFileName); - if (!await blobClient.ExistsAsync()) - { - return []; - } - - var response = await blobClient.DownloadAsync(); - using var streamReader = new StreamReader(response.Value.Content); - var content = await streamReader.ReadToEndAsync(); - - return [.. content - .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) - .Select(line => line.Trim()) - .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith('#'))]; - } - - public async Task GetChecksumAsync() - { - await InitAsync(); - - var blobClient = _containerClient.GetBlobClient(_checksumFileName); - if (!await blobClient.ExistsAsync()) - { - return string.Empty; - } - - var response = await blobClient.DownloadAsync(); - using var streamReader = new StreamReader(response.Value.Content); - return (await streamReader.ReadToEndAsync()).Trim(); - } - - public async Task UpdateDomainsAsync(IEnumerable domains, string checksum) - { - await InitAsync(); - - var domainsContent = string.Join(Environment.NewLine, domains); - var domainsStream = new MemoryStream(Encoding.UTF8.GetBytes(domainsContent)); - var domainsBlobClient = _containerClient.GetBlobClient(_domainsFileName); - - await domainsBlobClient.UploadAsync(domainsStream, new BlobUploadOptions - { - HttpHeaders = new BlobHttpHeaders { ContentType = "text/plain" } - }, CancellationToken.None); - - var checksumStream = new MemoryStream(Encoding.UTF8.GetBytes(checksum)); - var checksumBlobClient = _containerClient.GetBlobClient(_checksumFileName); - - await checksumBlobClient.UploadAsync(checksumStream, new BlobUploadOptions - { - HttpHeaders = new BlobHttpHeaders { ContentType = "text/plain" } - }, CancellationToken.None); - } - - private async Task InitAsync() - { - if (_containerClient is null) - { - _containerClient = _blobServiceClient.GetBlobContainerClient(_containerName); - await _containerClient.CreateIfNotExistsAsync(); - } - } -} diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs deleted file mode 100644 index 420948e310..0000000000 --- a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs +++ /dev/null @@ -1,100 +0,0 @@ -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -/// -/// Implementation of ICloudPhishingDomainQuery for cloud environments -/// that directly calls the external phishing domain source -/// -public class CloudPhishingDomainDirectQuery : ICloudPhishingDomainQuery -{ - private readonly IGlobalSettings _globalSettings; - private readonly IHttpClientFactory _httpClientFactory; - private readonly ILogger _logger; - - public CloudPhishingDomainDirectQuery( - IGlobalSettings globalSettings, - IHttpClientFactory httpClientFactory, - ILogger logger) - { - _globalSettings = globalSettings; - _httpClientFactory = httpClientFactory; - _logger = logger; - } - - public async Task> GetPhishingDomainsAsync() - { - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) - { - throw new InvalidOperationException("Phishing domain update URL is not configured."); - } - - var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); - var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.UpdateUrl); - response.EnsureSuccessStatusCode(); - - var content = await response.Content.ReadAsStringAsync(); - return ParseDomains(content); - } - - /// - /// Gets the SHA256 checksum of the remote phishing domains list - /// - /// The SHA256 checksum as a lowercase hex string - public async Task GetRemoteChecksumAsync() - { - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.ChecksumUrl)) - { - _logger.LogWarning("Phishing domain checksum URL is not configured."); - return string.Empty; - } - - try - { - var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); - var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.ChecksumUrl); - response.EnsureSuccessStatusCode(); - - var content = await response.Content.ReadAsStringAsync(); - return ParseChecksumResponse(content); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error retrieving phishing domain checksum from {Url}", - _globalSettings.PhishingDomain.ChecksumUrl); - return string.Empty; - } - } - - /// - /// Parses a checksum response in the format "hash *filename" - /// - private static string ParseChecksumResponse(string checksumContent) - { - if (string.IsNullOrWhiteSpace(checksumContent)) - { - return string.Empty; - } - - // Format is typically "hash *filename" - var parts = checksumContent.Split(' ', 2); - - return parts.Length > 0 ? parts[0].Trim() : string.Empty; - } - - private static List ParseDomains(string content) - { - if (string.IsNullOrWhiteSpace(content)) - { - return []; - } - - return content - .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) - .Select(line => line.Trim()) - .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#")) - .ToList(); - } -} diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs deleted file mode 100644 index 6b0027062c..0000000000 --- a/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs +++ /dev/null @@ -1,69 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Services; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -/// -/// Implementation of ICloudPhishingDomainQuery for self-hosted environments -/// that relays the request to the Bitwarden cloud API -/// -public class CloudPhishingDomainRelayQuery : BaseIdentityClientService, ICloudPhishingDomainQuery -{ - private readonly IGlobalSettings _globalSettings; - - public CloudPhishingDomainRelayQuery( - IHttpClientFactory httpFactory, - IGlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.Installation.ApiUri, - globalSettings.Installation.IdentityUri, - "api.licensing", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) - { - _globalSettings = globalSettings; - } - - public async Task> GetPhishingDomainsAsync() - { - if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) - { - throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); - } - - var result = await SendAsync(HttpMethod.Get, "phishing-domains", null, true); - return result?.ToList() ?? new List(); - } - - /// - /// Gets the SHA256 checksum of the remote phishing domains list - /// - /// The SHA256 checksum as a lowercase hex string - public async Task GetRemoteChecksumAsync() - { - if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) - { - throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); - } - - try - { - // For self-hosted environments, we get the checksum from the Bitwarden cloud API - var result = await SendAsync(HttpMethod.Get, "phishing-domains/checksum", null, true); - return result ?? string.Empty; - } - catch (Exception ex) - { - _logger.LogError(ex, "Error retrieving phishing domain checksum from Bitwarden cloud API"); - return string.Empty; - } - } -} diff --git a/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs b/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs deleted file mode 100644 index dac91747f7..0000000000 --- a/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.PhishingDomainFeatures.Interfaces; - -public interface ICloudPhishingDomainQuery -{ - Task> GetPhishingDomainsAsync(); - Task GetRemoteChecksumAsync(); -} diff --git a/src/Core/Platform/Mail/HandlebarsMailService.cs b/src/Core/Platform/Mail/HandlebarsMailService.cs index 072fe79e71..d57ca400fd 100644 --- a/src/Core/Platform/Mail/HandlebarsMailService.cs +++ b/src/Core/Platform/Mail/HandlebarsMailService.cs @@ -78,7 +78,7 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendRegistrationVerificationEmailAsync(string email, string token) + public async Task SendRegistrationVerificationEmailAsync(string email, string token, string? fromMarketing) { var message = CreateDefaultMessage("Verify Your Email", email); var model = new RegisterVerifyEmail @@ -86,7 +86,8 @@ public class HandlebarsMailService : IMailService Token = WebUtility.UrlEncode(token), Email = WebUtility.UrlEncode(email), WebVaultUrl = _globalSettings.BaseServiceUri.Vault, - SiteName = _globalSettings.SiteName + SiteName = _globalSettings.SiteName, + FromMarketing = WebUtility.UrlEncode(fromMarketing), }; await AddMessageContentAsync(message, "Auth.RegistrationVerifyEmail", model); message.MetaData.Add("SendGridBypassListManagement", true); @@ -424,6 +425,8 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + // TODO: DO NOT move to IMailer implementation: PM-27852 + [Obsolete("Use SendIndividualUserWelcomeEmailAsync instead")] public async Task SendWelcomeEmailAsync(User user) { var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); @@ -437,6 +440,50 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendIndividualUserWelcomeEmailAsync(User user) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new BaseMailModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-individual-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new OrganizationWelcomeEmailViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-org-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new OrganizationWelcomeEmailViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(familyOrganizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-family-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendTrialInitiationEmailAsync(string userEmail) { var message = CreateDefaultMessage("Welcome to Bitwarden; 3 steps to get started!", userEmail); diff --git a/src/Core/Platform/Mail/IMailService.cs b/src/Core/Platform/Mail/IMailService.cs index 52fbdb9b6d..e21e1a010f 100644 --- a/src/Core/Platform/Mail/IMailService.cs +++ b/src/Core/Platform/Mail/IMailService.cs @@ -15,9 +15,30 @@ namespace Bit.Core.Services; [Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public interface IMailService { + [Obsolete("Use SendIndividualUserWelcomeEmailAsync instead")] Task SendWelcomeEmailAsync(User user); + /// + /// Email sent to users who have created a new account as an individual user. + /// + /// The new User + /// Task + Task SendIndividualUserWelcomeEmailAsync(User user); + /// + /// Email sent to users who have been confirmed to an organization. + /// + /// The User + /// The Organization user is being added to + /// Task + Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName); + /// + /// Email sent to users who have been confirmed to a free or families organization. + /// + /// The User + /// The Families Organization user is being added to + /// Task + Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName); Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); - Task SendRegistrationVerificationEmailAsync(string email, string token); + Task SendRegistrationVerificationEmailAsync(string email, string token, string? fromMarketing); Task SendTrialInitiationSignupEmailAsync( bool isExistingUser, string email, diff --git a/src/Core/Platform/Mail/NoopMailService.cs b/src/Core/Platform/Mail/NoopMailService.cs index 45a860a155..7de48e4619 100644 --- a/src/Core/Platform/Mail/NoopMailService.cs +++ b/src/Core/Platform/Mail/NoopMailService.cs @@ -26,7 +26,7 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendRegistrationVerificationEmailAsync(string email, string hint) + public Task SendRegistrationVerificationEmailAsync(string email, string hint, string? fromMarketing) { return Task.FromResult(0); } @@ -114,6 +114,20 @@ public class NoopMailService : IMailService return Task.FromResult(0); } + public Task SendIndividualUserWelcomeEmailAsync(User user) + { + return Task.FromResult(0); + } + + public Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName) + { + return Task.FromResult(0); + } + + public Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName) + { + return Task.FromResult(0); + } public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) { return Task.FromResult(0); diff --git a/src/Core/Platform/Mail/README.md b/src/Core/Platform/Mail/README.md index b5caca62be..7a3b6b87c5 100644 --- a/src/Core/Platform/Mail/README.md +++ b/src/Core/Platform/Mail/README.md @@ -1,9 +1,14 @@ # Mail Services ## `MailService` -The `MailService` and its implementation in `HandlebarsMailService` has been deprecated in favor of the `Mailer` implementation. +> [!WARNING] +> The `MailService` and its implementation in `HandlebarsMailService` has been deprecated in favor of the `Mailer` implementation. -New emails should be implemented using [MJML](../../MailTemplates/README.md) and the `Mailer`. +The `MailService` class manages **all** emails, and has multiple responsibilities, including formatting, email building (instantiation of ViewModels from variables), and deciding if a mail request should be enqueued or sent directly. + +The resulting implementation cannot be owned by a single team (since all emails are in a single class), and as a result, anyone can edit any template without the appropriate team being informed. + +To alleviate these issues, all new emails should be implemented using [MJML](../../MailTemplates/README.md) and the `Mailer`. ## `Mailer` @@ -16,20 +21,20 @@ The Mailer system consists of four main components: 1. **IMailer** - Service interface for sending emails 2. **BaseMail** - Abstract base class defining email metadata (recipients, subject, category) -3. **BaseMailView** - Abstract base class for email template view models +3. **BaseMailView** - Abstract base class for email template ViewModels 4. **IMailRenderer** - Internal interface for rendering templates (implemented by `HandlebarMailRenderer`) ### How To Use -1. Define a view model that inherits from `BaseMailView` with properties for template data -2. Create Handlebars templates (`.html.hbs` and `.text.hbs`) as embedded resources, preferably using the MJML pipeline, - `/src/Core/MailTemplates/Mjml`. -3. Define an email class that inherits from `BaseMail` with metadata like subject -4. Use `IMailer.SendEmail()` to render and send the email +1. Define a ViewModel that inherits from `BaseMailView` with properties for template data. +2. Define an email class that inherits from `BaseMail` with metadata like `Subject`. +3. Create Handlebars templates (`.html.hbs` and `.text.hbs`) as embedded resources, preferably using the `MJML` [pipeline](../../MailTemplates/Mjml/README.md#development-process), in + a directory in `/src/Core/MailTemplates/Mjml`. +4. Use `IMailer.SendEmail()` to render and send the email. ### Creating a New Email -#### Step 1: Define the Email & View Model +#### Step 1: Define the ViewModel Create a class that inherits from `BaseMailView`: @@ -43,17 +48,25 @@ public class WelcomeEmailView : BaseMailView public required string UserName { get; init; } public required string ActivationUrl { get; init; } } +``` +#### Step 2: Define the email class + +Create a class that inherits from `BaseMail`: + +```csharp public class WelcomeEmail : BaseMail { public override string Subject => "Welcome to Bitwarden"; } ``` -#### Step 2: Create Handlebars Templates +#### Step 3: Create Handlebars templates -Create two template files as embedded resources next to your view model. **Important**: The file names must be located -directly next to the `ViewClass` and match the name of the view. +Create two template files as embedded resources next to your ViewModel. + +> [!IMPORTANT] +> The files must be located directly next to the `ViewClass` and match the name of the view. **WelcomeEmailView.html.hbs** (HTML version): @@ -87,7 +100,7 @@ Activate your account: {{ ActivationUrl }} ``` -#### Step 3: Send the Email +#### Step 4: Send the email Inject `IMailer` and send the email, this may be done in a service, command or some other application layer. @@ -160,7 +173,7 @@ public class MarketingEmail : BaseMail ### Built-in View Properties -All view models inherit from `BaseMailView`, which provides: +All ViewModels inherit from `BaseMailView`, which provides: - **CurrentYear** - The current UTC year (useful for copyright notices) @@ -176,7 +189,7 @@ Templates must follow this naming convention: - HTML template: `{ViewModelFullName}.html.hbs` - Text template: `{ViewModelFullName}.text.hbs` -For example, if your view model is `Bit.Core.Auth.Models.Mail.VerifyEmailView`, the templates must be: +For example, if your ViewModel is `Bit.Core.Auth.Models.Mail.VerifyEmailView`, the templates must be: - `Bit.Core.Auth.Models.Mail.VerifyEmailView.html.hbs` - `Bit.Core.Auth.Models.Mail.VerifyEmailView.text.hbs` @@ -210,4 +223,4 @@ services.TryAddSingleton(); The mail services support loading the mail template from disk. This is intended to be used by self-hosted customers who want to modify their email appearance. These overrides are not intended to be used during local development, as any changes there would not be reflected in the templates used in a normal deployment configuration. -Any customer using this override has worked with Bitwarden support on an approved implementation and has acknowledged that they are responsible for reacting to any changes made to the templates as a part of the Bitwarden development process. This includes, but is not limited to, changes in Handlebars property names, removal of properties from the `ViewModel` classes, and changes in template names. **Bitwarden is not responsible for maintaining backward compatibility between releases in order to support any overridden emails.** \ No newline at end of file +Any customer using this override has worked with Bitwarden support on an approved implementation and has acknowledged that they are responsible for reacting to any changes made to the templates as a part of the Bitwarden development process. This includes, but is not limited to, changes in Handlebars property names, removal of properties from the ViewModel classes, and changes in template names. **Bitwarden is not responsible for maintaining backward compatibility between releases in order to support any overridden emails.** \ No newline at end of file diff --git a/src/Core/Platform/Push/PushType.cs b/src/Core/Platform/Push/PushType.cs index 93eca86243..9a601ab0d3 100644 --- a/src/Core/Platform/Push/PushType.cs +++ b/src/Core/Platform/Push/PushType.cs @@ -95,5 +95,8 @@ public enum PushType : byte OrganizationBankAccountVerified = 23, [NotificationInfo("@bitwarden/team-billing-dev", typeof(Models.ProviderBankAccountVerifiedPushNotification))] - ProviderBankAccountVerified = 24 + ProviderBankAccountVerified = 24, + + [NotificationInfo("@bitwarden/team-admin-console-dev", typeof(Models.SyncPolicyPushNotification))] + PolicyChanged = 25, } diff --git a/src/Core/Repositories/IOrganizationDomainRepository.cs b/src/Core/Repositories/IOrganizationDomainRepository.cs index d802fe65df..b993cd42fa 100644 --- a/src/Core/Repositories/IOrganizationDomainRepository.cs +++ b/src/Core/Repositories/IOrganizationDomainRepository.cs @@ -17,4 +17,5 @@ public interface IOrganizationDomainRepository : IRepository GetDomainByOrgIdAndDomainNameAsync(Guid orgId, string domainName); Task> GetExpiredOrganizationDomainsAsync(); Task DeleteExpiredAsync(int expirationPeriod); + Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null); } diff --git a/src/Core/Repositories/IPhishingDomainRepository.cs b/src/Core/Repositories/IPhishingDomainRepository.cs deleted file mode 100644 index 2d653b0a43..0000000000 --- a/src/Core/Repositories/IPhishingDomainRepository.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Bit.Core.Repositories; - -public interface IPhishingDomainRepository -{ - Task> GetActivePhishingDomainsAsync(); - Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum); - Task GetCurrentChecksumAsync(); -} diff --git a/src/Core/Repositories/IUserRepository.cs b/src/Core/Repositories/IUserRepository.cs index 22effb4329..47ddb86f8e 100644 --- a/src/Core/Repositories/IUserRepository.cs +++ b/src/Core/Repositories/IUserRepository.cs @@ -1,4 +1,6 @@ -using Bit.Core.Entities; +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; @@ -23,6 +25,7 @@ public interface IUserRepository : IRepository /// Retrieves the data for the requested user IDs and includes an additional property indicating /// whether the user has premium access directly or through an organization. ///
    + [Obsolete("Use GetPremiumAccessByIdsAsync instead. This method will be removed in a future version.")] Task> GetManyWithCalculatedPremiumAsync(IEnumerable ids); /// /// Retrieves the data for the requested user ID and includes additional property indicating @@ -33,8 +36,23 @@ public interface IUserRepository : IRepository /// /// The user ID to retrieve data for. /// User data with calculated premium access; null if nothing is found + [Obsolete("Use GetPremiumAccessAsync instead. This method will be removed in a future version.")] Task GetCalculatedPremiumAsync(Guid userId); /// + /// Retrieves premium access status for multiple users. + /// For internal use - consumers should use IHasPremiumAccessQuery instead. + /// + /// The user IDs to check + /// Collection of UserPremiumAccess objects containing premium status information + Task> GetPremiumAccessByIdsAsync(IEnumerable ids); + /// + /// Retrieves premium access status for a single user. + /// For internal use - consumers should use IHasPremiumAccessQuery instead. + /// + /// The user ID to check + /// UserPremiumAccess object containing premium status information, or null if user not found + Task GetPremiumAccessAsync(Guid userId); + /// /// Sets a new user key and updates all encrypted data. /// Warning: Any user key encrypted data not included will be lost. /// @@ -44,5 +62,17 @@ public interface IUserRepository : IRepository IEnumerable updateDataActions); Task UpdateUserKeyAndEncryptedDataV2Async(User user, IEnumerable updateDataActions); + /// + /// Sets the account cryptographic state to a user in a single transaction. The provided + /// MUST be a V2 encryption state. Passing in a V1 encryption state will throw. + /// Extra actions can be passed in case other user data needs to be updated in the same transaction. + /// + Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null); Task DeleteManyAsync(IEnumerable users); } + +public delegate Task UpdateUserData(Microsoft.Data.SqlClient.SqlConnection? connection = null, + Microsoft.Data.SqlClient.SqlTransaction? transaction = null); diff --git a/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs b/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs deleted file mode 100644 index 2d4ea15b7e..0000000000 --- a/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs +++ /dev/null @@ -1,126 +0,0 @@ -using System.Text.Json; -using Bit.Core.PhishingDomainFeatures; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.Repositories.Implementations; - -public class AzurePhishingDomainRepository : IPhishingDomainRepository -{ - private readonly AzurePhishingDomainStorageService _storageService; - private readonly IDistributedCache _cache; - private readonly ILogger _logger; - private const string _domainsCacheKey = "PhishingDomains_v1"; - private const string _checksumCacheKey = "PhishingDomains_Checksum_v1"; - private static readonly DistributedCacheEntryOptions _cacheOptions = new() - { - AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24), - SlidingExpiration = TimeSpan.FromHours(1) - }; - - public AzurePhishingDomainRepository( - AzurePhishingDomainStorageService storageService, - IDistributedCache cache, - ILogger logger) - { - _storageService = storageService; - _cache = cache; - _logger = logger; - } - - public async Task> GetActivePhishingDomainsAsync() - { - try - { - var cachedDomains = await _cache.GetStringAsync(_domainsCacheKey); - if (!string.IsNullOrEmpty(cachedDomains)) - { - _logger.LogDebug("Retrieved phishing domains from cache"); - return JsonSerializer.Deserialize>(cachedDomains) ?? []; - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to retrieve phishing domains from cache"); - } - - var domains = await _storageService.GetDomainsAsync(); - - try - { - await _cache.SetStringAsync( - _domainsCacheKey, - JsonSerializer.Serialize(domains), - _cacheOptions); - _logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count); - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to store phishing domains in cache"); - } - - return domains; - } - - public async Task GetCurrentChecksumAsync() - { - try - { - var cachedChecksum = await _cache.GetStringAsync(_checksumCacheKey); - if (!string.IsNullOrEmpty(cachedChecksum)) - { - _logger.LogDebug("Retrieved phishing domain checksum from cache"); - return cachedChecksum; - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to retrieve phishing domain checksum from cache"); - } - - var checksum = await _storageService.GetChecksumAsync(); - - try - { - if (!string.IsNullOrEmpty(checksum)) - { - await _cache.SetStringAsync( - _checksumCacheKey, - checksum, - _cacheOptions); - _logger.LogDebug("Stored phishing domain checksum in cache"); - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to store phishing domain checksum in cache"); - } - - return checksum; - } - - public async Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum) - { - var domainsList = domains.ToList(); - await _storageService.UpdateDomainsAsync(domainsList, checksum); - - try - { - await _cache.SetStringAsync( - _domainsCacheKey, - JsonSerializer.Serialize(domainsList), - _cacheOptions); - - await _cache.SetStringAsync( - _checksumCacheKey, - checksum, - _cacheOptions); - - _logger.LogDebug("Updated phishing domains cache after update operation"); - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to update phishing domains in cache"); - } - } -} diff --git a/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs b/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs new file mode 100644 index 0000000000..b6dd1d778d --- /dev/null +++ b/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs @@ -0,0 +1,12 @@ +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Core.SecretsManager.Repositories; + +public interface ISecretVersionRepository +{ + Task GetByIdAsync(Guid id); + Task> GetManyBySecretIdAsync(Guid secretId); + Task> GetManyByIdsAsync(IEnumerable ids); + Task CreateAsync(SecretVersion secretVersion); + Task DeleteManyByIdAsync(IEnumerable ids); +} diff --git a/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs b/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs new file mode 100644 index 0000000000..caa5d96a7c --- /dev/null +++ b/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs @@ -0,0 +1,31 @@ +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Core.SecretsManager.Repositories.Noop; + +public class NoopSecretVersionRepository : ISecretVersionRepository +{ + public Task GetByIdAsync(Guid id) + { + return Task.FromResult(null as SecretVersion); + } + + public Task> GetManyBySecretIdAsync(Guid secretId) + { + return Task.FromResult(Enumerable.Empty()); + } + + public Task CreateAsync(SecretVersion secretVersion) + { + return Task.FromResult(secretVersion); + } + + public Task DeleteManyByIdAsync(IEnumerable ids) + { + return Task.CompletedTask; + } + + public Task> GetManyByIdsAsync(IEnumerable ids) + { + return Task.FromResult(Enumerable.Empty()); + } +} diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs deleted file mode 100644 index 6b2c3c299e..0000000000 --- a/src/Core/Services/IStripeAdapter.cs +++ /dev/null @@ -1,54 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; -using Stripe.Tax; - -namespace Bit.Core.Services; - -public interface IStripeAdapter -{ - Task CustomerCreateAsync(CustomerCreateOptions customerCreateOptions); - Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); - Task CustomerGetAsync(string id, CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task> CustomerListPaymentMethods(string id, CustomerPaymentMethodListOptions options = null); - Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options); - Task SubscriptionCreateAsync(SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null); - Task SubscriptionUpdateAsync(string id, SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null); - Task InvoiceGetAsync(string id, InvoiceGetOptions options); - Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); - Task> InvoiceSearchAsync(InvoiceSearchOptions options); - Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options); - Task InvoicePayAsync(string id, InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options); - IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null); - Task TaxIdCreateAsync(string id, TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); - Task> TaxRegistrationsListAsync(RegistrationListOptions options = null); - Task> ChargeListAsync(ChargeListOptions options); - Task RefundCreateAsync(RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); - Task> PriceListAsync(PriceListOptions options = null); - Task SetupIntentCreate(SetupIntentCreateOptions options); - Task> SetupIntentList(SetupIntentListOptions options); - Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null); - Task SetupIntentGet(string id, SetupIntentGetOptions options = null); - Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options); - Task> TestClockListAsync(); - Task PriceGetAsync(string id, PriceGetOptions options = null); -} diff --git a/src/Core/Services/IStripeSyncService.cs b/src/Core/Services/IStripeSyncService.cs deleted file mode 100644 index 655998805e..0000000000 --- a/src/Core/Services/IStripeSyncService.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Core.Services; - -public interface IStripeSyncService -{ - Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); -} diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 412f9db36e..fade63de51 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -4,7 +4,6 @@ using System.Security.Claims; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Entities; using Bit.Core.Enums; @@ -60,11 +59,23 @@ public interface IUserService Task CheckPasswordAsync(User user, string password); /// /// Checks if the user has access to premium features, either through a personal subscription or through an organization. + /// + /// This is the preferred way to definitively know if a user has access to premium features when you already have the User object. /// /// user being acted on /// true if they can access premium; false otherwise. - Task CanAccessPremium(ITwoFactorProvidersUser user); - Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); + Task CanAccessPremium(User user); + + /// + /// Checks if the user has inherited access to premium features through an organization. + /// + /// This primarily serves as a means to communicate to the client when a user has inherited their premium status + /// through an organization. Feature gating logic probably should not be behind this check. + /// + /// user being acted on + /// true if they can access premium because of organization membership; false otherwise. + [Obsolete("Use IHasPremiumAccessQuery.HasPremiumFromOrganizationAsync instead. This method will be removed in a future version.")] + Task HasPremiumFromOrganization(User user); Task GenerateSignInTokenAsync(User user, string purpose); Task UpdatePasswordHash(User user, string newPassword, diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs deleted file mode 100644 index 3d1663f021..0000000000 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ /dev/null @@ -1,284 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; -using Stripe.Tax; - -namespace Bit.Core.Services; - -public class StripeAdapter : IStripeAdapter -{ - private readonly CustomerService _customerService; - private readonly SubscriptionService _subscriptionService; - private readonly InvoiceService _invoiceService; - private readonly PaymentMethodService _paymentMethodService; - private readonly TaxIdService _taxIdService; - private readonly ChargeService _chargeService; - private readonly RefundService _refundService; - private readonly CardService _cardService; - private readonly BankAccountService _bankAccountService; - private readonly PlanService _planService; - private readonly PriceService _priceService; - private readonly SetupIntentService _setupIntentService; - private readonly Stripe.TestHelpers.TestClockService _testClockService; - private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; - private readonly Stripe.Tax.RegistrationService _taxRegistrationService; - private readonly CalculationService _calculationService; - - public StripeAdapter() - { - _customerService = new CustomerService(); - _subscriptionService = new SubscriptionService(); - _invoiceService = new InvoiceService(); - _paymentMethodService = new PaymentMethodService(); - _taxIdService = new TaxIdService(); - _chargeService = new ChargeService(); - _refundService = new RefundService(); - _cardService = new CardService(); - _bankAccountService = new BankAccountService(); - _priceService = new PriceService(); - _planService = new PlanService(); - _setupIntentService = new SetupIntentService(); - _testClockService = new Stripe.TestHelpers.TestClockService(); - _customerBalanceTransactionService = new CustomerBalanceTransactionService(); - _taxRegistrationService = new Stripe.Tax.RegistrationService(); - _calculationService = new CalculationService(); - } - - public Task CustomerCreateAsync(CustomerCreateOptions options) - { - return _customerService.CreateAsync(options); - } - - public Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => - _customerService.DeleteDiscountAsync(customerId, options); - - public Task CustomerGetAsync(string id, CustomerGetOptions options = null) - { - return _customerService.GetAsync(id, options); - } - - public Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null) - { - return _customerService.UpdateAsync(id, options); - } - - public Task CustomerDeleteAsync(string id) - { - return _customerService.DeleteAsync(id); - } - - public async Task> CustomerListPaymentMethods(string id, - CustomerPaymentMethodListOptions options = null) - { - var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); - return paymentMethods.Data; - } - - public async Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options) - => await _customerBalanceTransactionService.CreateAsync(customerId, options); - - public Task SubscriptionCreateAsync(SubscriptionCreateOptions options) - { - return _subscriptionService.CreateAsync(options); - } - - public Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null) - { - return _subscriptionService.GetAsync(id, options); - } - - public async Task ProviderSubscriptionGetAsync( - string id, - Guid providerId, - SubscriptionGetOptions options = null) - { - var subscription = await _subscriptionService.GetAsync(id, options); - if (subscription.Metadata.TryGetValue("providerId", out var value) && value == providerId.ToString()) - { - return subscription; - } - - throw new InvalidOperationException("Subscription does not belong to the provider."); - } - - public Task SubscriptionUpdateAsync(string id, - SubscriptionUpdateOptions options = null) - { - return _subscriptionService.UpdateAsync(id, options); - } - - public Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null) - { - return _subscriptionService.CancelAsync(Id, options); - } - - public Task InvoiceGetAsync(string id, InvoiceGetOptions options) - { - return _invoiceService.GetAsync(id, options); - } - - public async Task> InvoiceListAsync(StripeInvoiceListOptions options) - { - if (!options.SelectAll) - { - return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; - } - - options.Limit = 100; - - var invoices = new List(); - - await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) - { - invoices.Add(invoice); - } - - return invoices; - } - - public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options) - { - return _invoiceService.CreatePreviewAsync(options); - } - - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) - => (await _invoiceService.SearchAsync(options)).Data; - - public Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options) - { - return _invoiceService.UpdateAsync(id, options); - } - - public Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) - { - return _invoiceService.FinalizeInvoiceAsync(id, options); - } - - public Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options) - { - return _invoiceService.SendInvoiceAsync(id, options); - } - - public Task InvoicePayAsync(string id, InvoicePayOptions options = null) - { - return _invoiceService.PayAsync(id, options); - } - - public Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null) - { - return _invoiceService.DeleteAsync(id, options); - } - - public Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null) - { - return _invoiceService.VoidInvoiceAsync(id, options); - } - - public IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options) - { - return _paymentMethodService.ListAutoPaging(options); - } - - public IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options) - => _paymentMethodService.ListAutoPagingAsync(options); - - public Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null) - { - return _paymentMethodService.AttachAsync(id, options); - } - - public Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null) - { - return _paymentMethodService.DetachAsync(id, options); - } - - public Task PlanGetAsync(string id, PlanGetOptions options = null) - { - return _planService.GetAsync(id, options); - } - - public Task TaxIdCreateAsync(string id, TaxIdCreateOptions options) - { - return _taxIdService.CreateAsync(id, options); - } - - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - TaxIdDeleteOptions options = null) - { - return _taxIdService.DeleteAsync(customerId, taxIdId); - } - - public Task> TaxRegistrationsListAsync(RegistrationListOptions options = null) - { - return _taxRegistrationService.ListAsync(options); - } - - public Task> ChargeListAsync(ChargeListOptions options) - { - return _chargeService.ListAsync(options); - } - - public Task RefundCreateAsync(RefundCreateOptions options) - { - return _refundService.CreateAsync(options); - } - - public Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null) - { - return _cardService.DeleteAsync(customerId, cardId, options); - } - - public Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null) - { - return _bankAccountService.CreateAsync(customerId, options); - } - - public Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) - { - return _bankAccountService.DeleteAsync(customerId, bankAccount, options); - } - - public async Task> PriceListAsync(PriceListOptions options = null) - { - return await _priceService.ListAsync(options); - } - - public Task SetupIntentCreate(SetupIntentCreateOptions options) - => _setupIntentService.CreateAsync(options); - - public async Task> SetupIntentList(SetupIntentListOptions options) - { - var setupIntents = await _setupIntentService.ListAsync(options); - - return setupIntents.Data; - } - - public Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null) - => _setupIntentService.CancelAsync(id, options); - - public Task SetupIntentGet(string id, SetupIntentGetOptions options = null) - => _setupIntentService.GetAsync(id, options); - - public Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options) - => _setupIntentService.VerifyMicrodepositsAsync(id, options); - - public async Task> TestClockListAsync() - { - var items = new List(); - var options = new Stripe.TestHelpers.TestClockListOptions() - { - Limit = 100 - }; - await foreach (var i in _testClockService.ListAutoPagingAsync(options)) - { - items.Add(i); - } - return items; - } - - public Task PriceGetAsync(string id, PriceGetOptions options = null) - => _priceService.GetAsync(id, options); -} diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index daf1b2078d..8db66211b1 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -17,6 +17,7 @@ using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Premium.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; @@ -57,7 +58,7 @@ public class UserService : UserManager, IUserService private readonly ILicensingService _licenseService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly IFido2 _fido2; @@ -73,6 +74,7 @@ public class UserService : UserManager, IUserService private readonly IDistributedCache _distributedCache; private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPricingClient _pricingClient; + private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery; public UserService( IUserRepository userRepository, @@ -93,7 +95,7 @@ public class UserService : UserManager, IUserService ILicensingService licenseService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, IFido2 fido2, @@ -108,7 +110,8 @@ public class UserService : UserManager, IUserService ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IDistributedCache distributedCache, IPolicyRequirementQuery policyRequirementQuery, - IPricingClient pricingClient) + IPricingClient pricingClient, + IHasPremiumAccessQuery hasPremiumAccessQuery) : base( store, optionsAccessor, @@ -149,6 +152,7 @@ public class UserService : UserManager, IUserService _distributedCache = distributedCache; _policyRequirementQuery = policyRequirementQuery; _pricingClient = pricingClient; + _hasPremiumAccessQuery = hasPremiumAccessQuery; } public Guid? GetProperUserId(ClaimsPrincipal principal) @@ -534,7 +538,7 @@ public class UserService : UserManager, IUserService try { - await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, + await _stripeSyncService.UpdateCustomerEmailAddressAsync(user.GatewayCustomerId, user.BillingEmailAddress()); } catch (Exception ex) @@ -867,7 +871,7 @@ public class UserService : UserManager, IUserService } string paymentIntentClientSecret = null; - IPaymentService paymentService = null; + IStripePaymentService paymentService = null; if (_globalSettings.SelfHosted) { if (license == null || !_licenseService.VerifyLicense(license)) @@ -904,7 +908,6 @@ public class UserService : UserManager, IUserService } else { - user.MaxStorageGb = (short)(1 + additionalStorageGb); user.LicenseKey = CoreHelpers.SecureRandomString(20); } @@ -977,7 +980,8 @@ public class UserService : UserManager, IUserService var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, premiumPlan.Storage.StripePriceId); + var baseStorageGb = (short)premiumPlan.Storage.Provided; + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, premiumPlan.Storage.StripePriceId, baseStorageGb); await SaveUserAsync(user); return secret; } @@ -1104,7 +1108,7 @@ public class UserService : UserManager, IUserService return success; } - public async Task CanAccessPremium(ITwoFactorProvidersUser user) + public async Task CanAccessPremium(User user) { var userId = user.GetUserId(); if (!userId.HasValue) @@ -1112,10 +1116,15 @@ public class UserService : UserManager, IUserService return false; } - return user.GetPremium() || await this.HasPremiumFromOrganization(user); + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return user.Premium || await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value); + } + + return user.Premium || await HasPremiumFromOrganization(user); } - public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) + public async Task HasPremiumFromOrganization(User user) { var userId = user.GetUserId(); if (!userId.HasValue) @@ -1123,6 +1132,11 @@ public class UserService : UserManager, IUserService return false; } + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value); + } + // orgUsers in the Invited status are not associated with a userId yet, so this will get // orgUsers in Accepted and Confirmed states only var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); @@ -1138,6 +1152,7 @@ public class UserService : UserManager, IUserService orgAbility.UsersGetPremium && orgAbility.Enabled); } + public async Task GenerateSignInTokenAsync(User user, string purpose) { var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index c467d1e652..f030c73809 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -2,14 +2,12 @@ #nullable disable using Bit.Core.Auth.Settings; -using Bit.Core.Settings.LoggingSettings; namespace Bit.Core.Settings; public class GlobalSettings : IGlobalSettings { private string _mailTemplateDirectory; - private string _logDirectory; private string _licenseDirectory; public GlobalSettings() @@ -21,18 +19,10 @@ public class GlobalSettings : IGlobalSettings } public bool SelfHosted { get; set; } - public bool UnifiedDeployment { get; set; } + public bool LiteDeployment { get; set; } public virtual string KnownProxies { get; set; } public virtual string SiteName { get; set; } public virtual string ProjectName { get; set; } - public virtual string LogDirectory - { - get => BuildDirectory(_logDirectory, "/logs"); - set => _logDirectory = value; - } - public virtual bool LogDirectoryByProject { get; set; } = true; - public virtual long? LogRollBySizeLimit { get; set; } - public virtual bool EnableDevLogging { get; set; } = false; public virtual string LicenseDirectory { get => BuildDirectory(_licenseDirectory, "/core/licenses"); @@ -66,16 +56,13 @@ public class GlobalSettings : IGlobalSettings public virtual EventLoggingSettings EventLogging { get; set; } = new EventLoggingSettings(); public virtual MailSettings Mail { get; set; } = new MailSettings(); public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); - public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); + public virtual AzureQueueEventSettings Events { get; set; } = new AzureQueueEventSettings(); public virtual DistributedCacheSettings DistributedCache { get; set; } = new DistributedCacheSettings(); public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); public virtual IFileStorageSettings Attachment { get; set; } public virtual FileStorageSettings Send { get; set; } public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); public virtual DataProtectionSettings DataProtection { get; set; } - public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); - public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); - public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings(); public virtual NotificationHubPoolSettings NotificationHubPool { get; set; } = new(); public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); public virtual DuoSettings Duo { get; set; } = new DuoSettings(); @@ -94,7 +81,6 @@ public class GlobalSettings : IGlobalSettings public virtual ILaunchDarklySettings LaunchDarkly { get; set; } = new LaunchDarklySettings(); public virtual string DevelopmentDirectory { get; set; } public virtual IWebPushSettings WebPush { get; set; } = new WebPushSettings(); - public virtual IPhishingDomainSettings PhishingDomain { get; set; } = new PhishingDomainSettings(); public virtual int SendAccessTokenLifetimeInMinutes { get; set; } = 5; public virtual bool EnableEmailVerification { get; set; } @@ -408,6 +394,24 @@ public class GlobalSettings : IGlobalSettings } } + public class AzureQueueEventSettings : IConnectionStringSettings + { + private string _connectionString; + private string _queueName; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value?.Trim('"'); + } + + public string QueueName + { + get => _queueName; + set => _queueName = value?.Trim('"'); + } + } + public class ConnectionStringSettings : IConnectionStringSettings { private string _connectionString; @@ -496,7 +500,7 @@ public class GlobalSettings : IGlobalSettings public string CertificatePassword { get; set; } public string RedisConnectionString { get; set; } public string CosmosConnectionString { get; set; } - public string LicenseKey { get; set; } = "eyJhbGciOiJQUzI1NiIsImtpZCI6IklkZW50aXR5U2VydmVyTGljZW5zZWtleS83Y2VhZGJiNzgxMzA0NjllODgwNjg5MTAyNTQxNGYxNiIsInR5cCI6ImxpY2Vuc2Urand0In0.eyJpc3MiOiJodHRwczovL2R1ZW5kZXNvZnR3YXJlLmNvbSIsImF1ZCI6IklkZW50aXR5U2VydmVyIiwiaWF0IjoxNzM0NTY2NDAwLCJleHAiOjE3NjQ5NzkyMDAsImNvbXBhbnlfbmFtZSI6IkJpdHdhcmRlbiBJbmMuIiwiY29udGFjdF9pbmZvIjoiY29udGFjdEBkdWVuZGVzb2Z0d2FyZS5jb20iLCJlZGl0aW9uIjoiU3RhcnRlciIsImlkIjoiNjg3OCIsImZlYXR1cmUiOlsiaXN2IiwidW5saW1pdGVkX2NsaWVudHMiXSwicHJvZHVjdCI6IkJpdHdhcmRlbiJ9.TYc88W_t2t0F2AJV3rdyKwGyQKrKFriSAzm1tWFNHNR9QizfC-8bliGdT4Wgeie-ynCXs9wWaF-sKC5emg--qS7oe2iIt67Qd88WS53AwgTvAddQRA4NhGB1R7VM8GAikLieSos-DzzwLYRgjZdmcsprItYGSJuY73r-7-F97ta915majBytVxGF966tT9zF1aYk0bA8FS6DcDYkr5f7Nsy8daS_uIUAgNa_agKXtmQPqKujqtUb6rgWEpSp4OcQcG-8Dpd5jHqoIjouGvY-5LTgk5WmLxi_m-1QISjxUJrUm-UGao3_VwV5KFGqYrz8csdTl-HS40ihWcsWnrV0ug"; + public string LicenseKey { get; set; } = "eyJhbGciOiJQUzI1NiIsImtpZCI6IklkZW50aXR5U2VydmVyTGljZW5zZUtleS83Y2VhZGJiNzgxMzA0NjllODgwNjg5MTAyNTQxNGYxNiIsInR5cCI6ImxpY2Vuc2Urand0In0.eyJpc3MiOiJodHRwczovL2R1ZW5kZXNvZnR3YXJlLmNvbSIsImF1ZCI6IklkZW50aXR5U2VydmVyIiwiaWF0IjoxNzY1MDY1NjAwLCJleHAiOjE3OTY1MTUyMDAsImNvbXBhbnlfbmFtZSI6IkJpdHdhcmRlbiBJbmMuIiwiY29udGFjdF9pbmZvIjoiY29udGFjdEBkdWVuZGVzb2Z0d2FyZS5jb20iLCJlZGl0aW9uIjoiU3RhcnRlciIsImlkIjoiOTUxNSIsImZlYXR1cmUiOlsiaXN2IiwidW5saW1pdGVkX2NsaWVudHMiXSwiY2xpZW50X2xpbWl0IjowfQ.rWUsq-XBKNwPG7BRKG-vShXHuyHLHJCh0sEWdWT4Rkz4ArIPOAepEp9wNya-hxFKkBTFlPaQ5IKk4wDTvkQkuq1qaI_v6kSCdaP9fvXp0rmh4KcFEffVLB-wAOK2S2Cld5DzdyCoskUUfwNQP7xuLsz2Ydxe_whSRIdv8bsMbvTC3Kl8PYZPZ4MxqW8rSZ_mEuCpSe5-Q40sB7aiu_7YmWLJaKrfBTIqYH-XuzQj36Aemoei0efcntej-gvxovy-5SiSEsGuRZj41rjEZYOuj5KgHihJViO1VDHK6CNtlu2Ks8bkv6G2hO-TkF16Y28ywEG_beLEf_s5dzhbDBDbvA"; /// /// Sliding lifetime of a refresh token in seconds. /// @@ -548,59 +552,11 @@ public class GlobalSettings : IGlobalSettings } } - public class SentrySettings - { - public string Dsn { get; set; } - } - public class NotificationsSettings : ConnectionStringSettings { public string RedisConnectionString { get; set; } } - public class SyslogSettings - { - /// - /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. - /// - /// - /// The connection string will be parsed using to extract the protocol, host name and port number. - /// - /// - /// Supported protocols are: - /// - /// UDP (use udp://) - /// TCP (use tcp://) - /// TLS over TCP (use tls://) - /// - /// - /// - /// - /// A remote server (logging.dev.example.com) is listening on UDP (port 514): - /// - /// udp://logging.dev.example.com:514. - /// - public string Destination { get; set; } - /// - /// The absolute path to a Certificate (DER or Base64 encoded with private key). - /// - /// - /// The certificate path and are passed into the . - /// The file format of the certificate may be binary encoded (DER) or base64. If the private key is encrypted, provide the password in , - /// - public string CertificatePath { get; set; } - /// - /// The password for the encrypted private key in the certificate supplied in . - /// - /// - public string CertificatePassword { get; set; } - /// - /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. - /// - /// - public string CertificateThumbprint { get; set; } - } - public class NotificationHubSettings { private string _connectionString; @@ -733,12 +689,6 @@ public class GlobalSettings : IGlobalSettings public int MaxNetworkRetries { get; set; } = 2; } - public class PhishingDomainSettings : IPhishingDomainSettings - { - public string UpdateUrl { get; set; } - public string ChecksumUrl { get; set; } - } - public class DistributedIpRateLimitingSettings { public string RedisConnectionString { get; set; } @@ -783,6 +733,30 @@ public class GlobalSettings : IGlobalSettings { public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); public virtual IConnectionStringSettings Cosmos { get; set; } = new ConnectionStringSettings(); + public ExtendedCacheSettings DefaultExtendedCache { get; set; } = new ExtendedCacheSettings(); + } + + /// + /// A collection of Settings for customizing the FusionCache used in extended caching. Defaults are + /// provided for every attribute so that only specific values need to be overridden if needed. + /// + public class ExtendedCacheSettings + { + public bool EnableDistributedCache { get; set; } = true; + public bool UseSharedDistributedCache { get; set; } = true; + public IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); + public TimeSpan Duration { get; set; } = TimeSpan.FromMinutes(30); + public bool IsFailSafeEnabled { get; set; } = true; + public TimeSpan FailSafeMaxDuration { get; set; } = TimeSpan.FromHours(2); + public TimeSpan FailSafeThrottleDuration { get; set; } = TimeSpan.FromSeconds(30); + public float? EagerRefreshThreshold { get; set; } = 0.9f; + public TimeSpan FactorySoftTimeout { get; set; } = TimeSpan.FromMilliseconds(100); + public TimeSpan FactoryHardTimeout { get; set; } = TimeSpan.FromMilliseconds(1500); + public TimeSpan DistributedCacheSoftTimeout { get; set; } = TimeSpan.FromSeconds(1); + public TimeSpan DistributedCacheHardTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool AllowBackgroundDistributedCacheOperations { get; set; } = true; + public TimeSpan JitterMaxDuration { get; set; } = TimeSpan.FromSeconds(2); + public TimeSpan DistributedCacheCircuitBreakerDuration { get; set; } = TimeSpan.FromSeconds(30); } public class WebPushSettings : IWebPushSettings diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index d77842373e..06dece3394 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -6,7 +6,7 @@ public interface IGlobalSettings { // This interface exists for testing. Add settings here as needed for testing bool SelfHosted { get; set; } - bool UnifiedDeployment { get; set; } + bool LiteDeployment { get; set; } string KnownProxies { get; set; } string ProjectName { get; set; } bool EnableCloudCommunication { get; set; } @@ -20,7 +20,6 @@ public interface IGlobalSettings IConnectionStringSettings Storage { get; set; } IBaseServiceUriSettings BaseServiceUri { get; set; } ISsoSettings Sso { get; set; } - ILogLevelSettings MinLogLevel { get; set; } IPasswordlessAuthSettings PasswordlessAuth { get; set; } IDomainVerificationSettings DomainVerification { get; set; } ILaunchDarklySettings LaunchDarkly { get; set; } @@ -29,5 +28,4 @@ public interface IGlobalSettings string DevelopmentDirectory { get; set; } IWebPushSettings WebPush { get; set; } GlobalSettings.EventLoggingSettings EventLogging { get; set; } - IPhishingDomainSettings PhishingDomain { get; set; } } diff --git a/src/Core/Settings/ILogLevelSettings.cs b/src/Core/Settings/ILogLevelSettings.cs deleted file mode 100644 index b3cedf083c..0000000000 --- a/src/Core/Settings/ILogLevelSettings.cs +++ /dev/null @@ -1,74 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings; - -public interface ILogLevelSettings -{ - IBillingLogLevelSettings BillingSettings { get; set; } - IApiLogLevelSettings ApiSettings { get; set; } - IIdentityLogLevelSettings IdentitySettings { get; set; } - IScimLogLevelSettings ScimSettings { get; set; } - ISsoLogLevelSettings SsoSettings { get; set; } - IAdminLogLevelSettings AdminSettings { get; set; } - IEventsLogLevelSettings EventsSettings { get; set; } - IEventsProcessorLogLevelSettings EventsProcessorSettings { get; set; } - IIconsLogLevelSettings IconsSettings { get; set; } - INotificationsLogLevelSettings NotificationsSettings { get; set; } -} - -public interface IBillingLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel Jobs { get; set; } -} - -public interface IApiLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } - LogEventLevel IpRateLimit { get; set; } -} - -public interface IIdentityLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } - LogEventLevel IpRateLimit { get; set; } -} - -public interface IScimLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface ISsoLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IAdminLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IEventsLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } -} - -public interface IEventsProcessorLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IIconsLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface INotificationsLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } -} diff --git a/src/Core/Settings/IPhishingDomainSettings.cs b/src/Core/Settings/IPhishingDomainSettings.cs deleted file mode 100644 index 2e4a901a5a..0000000000 --- a/src/Core/Settings/IPhishingDomainSettings.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.Settings; - -public interface IPhishingDomainSettings -{ - string UpdateUrl { get; set; } - string ChecksumUrl { get; set; } -} diff --git a/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs deleted file mode 100644 index d2c74dd076..0000000000 --- a/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class AdminLogLevelSettings : IAdminLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs deleted file mode 100644 index 7961ab7e3b..0000000000 --- a/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class ApiLogLevelSettings : IApiLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; - public LogEventLevel IpRateLimit { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs deleted file mode 100644 index b9e53e6bca..0000000000 --- a/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class BillingLogLevelSettings : IBillingLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; - public LogEventLevel Jobs { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs deleted file mode 100644 index 3201748550..0000000000 --- a/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class EventsLogLevelSettings : IEventsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; -} diff --git a/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs deleted file mode 100644 index 5aff18a216..0000000000 --- a/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class EventsProcessorLogLevelSettings : IEventsProcessorLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; -} diff --git a/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs deleted file mode 100644 index c7b73ba687..0000000000 --- a/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class IconsLogLevelSettings : IIconsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs deleted file mode 100644 index a823cb5109..0000000000 --- a/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class IdentityLogLevelSettings : IIdentityLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; - public LogEventLevel IpRateLimit { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/LogLevelSettings.cs b/src/Core/Settings/LoggingSettings/LogLevelSettings.cs deleted file mode 100644 index 1af05ebfde..0000000000 --- a/src/Core/Settings/LoggingSettings/LogLevelSettings.cs +++ /dev/null @@ -1,16 +0,0 @@ - -namespace Bit.Core.Settings.LoggingSettings; - -public class LogLevelSettings : ILogLevelSettings -{ - public IBillingLogLevelSettings BillingSettings { get; set; } = new BillingLogLevelSettings(); - public IApiLogLevelSettings ApiSettings { get; set; } = new ApiLogLevelSettings(); - public IIdentityLogLevelSettings IdentitySettings { get; set; } = new IdentityLogLevelSettings(); - public IScimLogLevelSettings ScimSettings { get; set; } = new ScimLogLevelSettings(); - public ISsoLogLevelSettings SsoSettings { get; set; } = new SsoLogLevelSettings(); - public IAdminLogLevelSettings AdminSettings { get; set; } = new AdminLogLevelSettings(); - public IEventsLogLevelSettings EventsSettings { get; set; } = new EventsLogLevelSettings(); - public IEventsProcessorLogLevelSettings EventsProcessorSettings { get; set; } = new EventsProcessorLogLevelSettings(); - public IIconsLogLevelSettings IconsSettings { get; set; } = new IconsLogLevelSettings(); - public INotificationsLogLevelSettings NotificationsSettings { get; set; } = new NotificationsLogLevelSettings(); -} diff --git a/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs deleted file mode 100644 index 3494fbfcca..0000000000 --- a/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class NotificationsLogLevelSettings : INotificationsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; -} diff --git a/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs deleted file mode 100644 index f297b17e95..0000000000 --- a/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class ScimLogLevelSettings : IScimLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; -} diff --git a/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs deleted file mode 100644 index 495ec41fd0..0000000000 --- a/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class SsoLogLevelSettings : ISsoLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index c7f7e3aff7..fa558f5963 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -150,17 +150,34 @@ public class ImportCiphersCommand : IImportCiphersCommand foreach (var collection in collections) { - if (!organizationCollectionsIds.Contains(collection.Id)) + // If the collection already exists, skip it + if (organizationCollectionsIds.Contains(collection.Id)) { - collection.SetNewId(); - newCollections.Add(collection); - newCollectionUsers.Add(new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = importingOrgUser.Id, - Manage = true - }); + continue; } + + // Create new collections if not already present + collection.SetNewId(); + newCollections.Add(collection); + + /* + * If the organization was created by a Provider, the organization may have zero members (users) + * In this situation importingOrgUser will be null, and accessing importingOrgUser.Id will + * result in a null reference exception. + * + * Avoid user assignment, but proceed with adding the collection. + */ + if (importingOrgUser == null) + { + continue; + } + + newCollectionUsers.Add(new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = importingOrgUser.Id, + Manage = true + }); } // Create associations based on the newly assigned ids diff --git a/src/Core/Utilities/BillingHelpers.cs b/src/Core/Utilities/BillingHelpers.cs index e7ccfc3547..ef0fdf010b 100644 --- a/src/Core/Utilities/BillingHelpers.cs +++ b/src/Core/Utilities/BillingHelpers.cs @@ -1,13 +1,13 @@ -using Bit.Core.Entities; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; using Bit.Core.Exceptions; -using Bit.Core.Services; namespace Bit.Core.Utilities; public static class BillingHelpers { - internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, - short storageAdjustmentGb, string storagePlanId) + internal static async Task AdjustStorageAsync(IStripePaymentService paymentService, IStorableSubscriber storableSubscriber, + short storageAdjustmentGb, string storagePlanId, short baseStorageGb) { if (storableSubscriber == null) { @@ -30,9 +30,9 @@ public static class BillingHelpers } var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); - if (newStorageGb < 1) + if (newStorageGb < baseStorageGb) { - newStorageGb = 1; + newStorageGb = baseStorageGb; } if (newStorageGb > 100) @@ -48,7 +48,7 @@ public static class BillingHelpers "Delete some stored data first."); } - var additionalStorage = newStorageGb - 1; + var additionalStorage = newStorageGb - baseStorageGb; var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, additionalStorage, storagePlanId); storableSubscriber.MaxStorageGb = newStorageGb; diff --git a/src/Core/Utilities/CACHING.md b/src/Core/Utilities/CACHING.md new file mode 100644 index 0000000000..c29a14d751 --- /dev/null +++ b/src/Core/Utilities/CACHING.md @@ -0,0 +1,1123 @@ +# Bitwarden Server Caching + +Caching options available in Bitwarden's server. The server uses multiple caching layers and backends to balance performance, scalability, and operational simplicity across both cloud and self-hosted deployments. + +--- + +## Choosing a Caching Option + +Use this decision tree to identify the appropriate caching option for your feature: + +``` +Does your data need to be shared across all instances in a horizontally-scaled deployment? +├─ YES +│ │ +│ Do you need long-term persistence with TTL (days/weeks)? +│ ├─ YES → Use `IDistributedCache` with persistent keyed service +│ └─ NO → Use `ExtendedCache` +│ │ +│ Notes: +│ - With Redis configured: memory + distributed + backplane +│ - Without Redis: memory-only with stampede protection +│ - Provides fail-safe, eager refresh, circuit breaker +│ - For org/provider abilities: Use GetOrSetAsync with preloading pattern +│ +└─ NO (single instance or manual sync acceptable) + │ + Use `ExtendedCache` with memory-only mode (EnableDistributedCache = false) + │ + Notes: + - Same performance as raw IMemoryCache + - Built-in stampede protection, eager refresh, fail-safe + - "Free" Redis/backplane if needed at a later date (but not required) + - Only use specialized in-memory cache if ExtendedCache API doesn't fit + +*Stampede protection = prevents cache stampedes (multiple simultaneous requests for the same expired/missing key triggering redundant backend calls) +``` + +--- + +## Caching Options Overview + +| Option | Best For | Horizontal Scale | TTL Support | Backend Options | +| -------------------------------------- | ---------------------------------------------- | ---------------- | ----------- | ---------------------- | +| **ExtendedCache** | General-purpose caching with advanced features | ✅ Yes | ✅ Yes | Redis, Memory | +| **IDistributedCache** (default) | Short-lived key-value caching | ✅ Yes | ⚠️ Manual | Redis, SQL, EF | +| **IDistributedCache** (`"persistent"`) | Long-lived data with TTL | ✅ Yes | ✅ Yes | Cosmos, Redis, SQL, EF | +| **In-Memory Cache** | High-frequency reads, single instance | ❌ No | ⚠️ Manual | Memory | + +--- + +## `ExtendedCache` + +`ExtendedCache` is a wrapper around [FusionCache](https://github.com/ZiggyCreatures/FusionCache) that provides a simple way to register **named, isolated caches** with sensible defaults. The goal is to make it trivial for each subsystem or feature to have its own cache - with optional distributed caching and backplane support - without repeatedly wiring up FusionCache, Redis, and related infrastructure. + +Each named cache automatically receives: + +- Its own `FusionCache` instance +- Its own configuration (default or overridden) +- Its own key prefix +- Optional distributed store +- Optional backplane + +`ExtendedCache` supports three deployment modes: + +- **Memory-only caching** (with stampede protection: prevents multiple concurrent requests for the same key from hitting the backend) +- **Memory + distributed cache + backplane** using the **shared** application Redis +- **Memory + distributed cache + backplane** using a **fully isolated** Redis instance + +### When to Use + +- **General-purpose caching** for any domain data +- Features requiring **stampede protection** (when multiple concurrent requests for the same cache key should result in only a single backend call, with all requesters waiting for the same result) +- Data that benefits from **fail-safe mode** (serve stale data on backend failures) +- Multi-instance applications requiring **cache synchronization** via backplane +- You want **isolated cache configuration** per feature + +### Pros + +✅ **Advanced features out-of-the-box**: + +- Stampede protection (multiple requests for same key = single backend call) +- Fail-safe mode with stale data serving +- Adaptive caching with eager refresh +- Automatic backplane support for multi-instance sync +- Circuit breaker for backend failures + +✅ **Named, isolated caches**: Each feature gets its own cache instance with independent configuration + +✅ **Flexible deployment modes**: + +- Memory-only (development, testing) +- Memory + Redis (production cloud) +- Memory + isolated Redis (specialized features) + +✅ **Simple API**: Uses `FusionCache`'s intuitive `GetOrSet` pattern + +✅ **Built-in serialization**: Automatic JSON serialization/deserialization + +### Cons + +❌ Requires understanding of `FusionCache` configuration options + +❌ Slightly more overhead than raw `IDistributedCache` + +❌ IDistributedCache dependency for multi-instance deployments (typically Redis, but degrades gracefully to memory-only) + +### Example Usage + +**Note**: When using the shared Redis cache option (which is on by default, if the Redis connection string is configured), it is expected to call `services.AddDistributedCache(globalSettings)` **before** calling `AddExtendedCache`. The idea is to set up the distributed cache in our normal pattern and then "extend" it to include more functionality. + +#### 1. Register the cache (in Startup.cs): + +```csharp +// Option 1: Use default settings with shared Redis (if available) +services.AddDistributedCache(globalSettings); +services.AddExtendedCache("MyFeatureCache", globalSettings); + +// Option 2: Memory-only mode for high-performance single-instance caching +services.AddExtendedCache("MyFeatureCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + EnableDistributedCache = false, // Memory-only, same performance as IMemoryCache + Duration = TimeSpan.FromHours(1), + IsFailSafeEnabled = true, + EagerRefreshThreshold = 0.9 // Refresh at 90% of TTL +}); +// When EnableDistributedCache = false: +// - Uses memory-only caching (same performance as raw IMemoryCache) +// - Still provides stampede protection, eager refresh, fail-safe +// - Redis/backplane can be enabled later by setting EnableDistributedCache = true + +// Option 3: Override default settings with Redis +services.AddExtendedCache("MyFeatureCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromHours(1), + IsFailSafeEnabled = true, + FailSafeMaxDuration = TimeSpan.FromHours(2), + EagerRefreshThreshold = 0.9 // Refresh at 90% of TTL +}); + +// Option 4: Isolated Redis for specialized features +services.AddExtendedCache("SpecializedCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + UseSharedDistributedCache = false, + Redis = new GlobalSettings.ConnectionStringSettings + { + ConnectionString = "localhost:6379,ssl=false" + } +}); +// When configured this way: +// - A dedicated IConnectionMultiplexer is created +// - A dedicated IDistributedCache is created +// - A dedicated FusionCache backplane is created +// - All three are exposed to DI as keyed services (using the cache name as service key) +``` + +#### 2. Inject and use the cache: + +A named cache is retrieved via DI using keyed services (similar to how [IHttpClientFactory](https://learn.microsoft.com/en-us/aspnet/core/fundamentals/http-requests?view=aspnetcore-7.0#named-clients) works with named clients): + +```csharp +public class MyService +{ + private readonly IFusionCache _cache; + private readonly IItemRepository _itemRepository; + + // Option A: Inject via keyed service in constructor + public MyService( + [FromKeyedServices("MyFeatureCache")] IFusionCache cache, + IItemRepository itemRepository) + { + _cache = cache; + _itemRepository = itemRepository; + } + + // Option B: Request manually from service provider + // cache = provider.GetRequiredKeyedService(serviceKey: "MyFeatureCache") + + // Option C: Inject IFusionCacheProvider and request the named cache + // (similar to IHttpClientFactory pattern) + public MyService( + IFusionCacheProvider cacheProvider, + IItemRepository itemRepository) + { + _cache = cacheProvider.GetCache("MyFeatureCache"); + _itemRepository = itemRepository; + } + + public async Task GetItemAsync(Guid id) + { + return await _cache.GetOrSetAsync( + $"item:{id}", + async _ => await _itemRepository.GetByIdAsync(id), + options => options.SetDuration(TimeSpan.FromMinutes(30)) + ); + } +} +``` + +`ExtendedCache` doesn't change how `FusionCache` is used in code, which means all the functionality and full `FusionCache` API is available. See the [FusionCache docs](https://github.com/ZiggyCreatures/FusionCache/blob/main/docs/CoreMethods.md) for more details. + +### Specific Example: SSO Authorization Grants + +SSO authorization grants are **ephemeral, short-lived data** (typically ≤5 minutes) used to coordinate authorization flows across horizontally-scaled instances. `ExtendedCache` is ideal for this use case: + +```csharp +services.AddExtendedCache("SsoGrants", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromMinutes(5), + IsFailSafeEnabled = false // Re-initiate flow rather than serve stale grants +}); + +public class SsoAuthorizationService +{ + private readonly IFusionCache _cache; + + public SsoAuthorizationService([FromKeyedServices("SsoGrants")] IFusionCache cache) + { + _cache = cache; + } + + public async Task GetGrantAsync(string authorizationCode) + { + return await _cache.GetOrDefaultAsync($"sso:grant:{authorizationCode}"); + } + + public async Task StoreGrantAsync(string authorizationCode, SsoGrant grant) + { + await _cache.SetAsync($"sso:grant:{authorizationCode}", grant); + } +} +``` + +**Why `ExtendedCache` for SSO grants:** + +- **Not critical if lost**: User can re-initiate SSO flow +- **Lower latency**: Redis backplane is faster than persistent storage +- **Simpler infrastructure**: Reuses existing Redis connection +- **Horizontal scaling**: Redis backplane automatically synchronizes across instances + +### Backend Configuration + +`ExtendedCache` automatically uses the configured backend: + +**Cloud (Bitwarden-hosted)**: + +1. Redis (primary, if `GlobalSettings.DistributedCache.Redis.ConnectionString` configured) +2. Memory-only (fallback if Redis unavailable) + +**Self-hosted**: + +1. Redis (if configured in `appsettings.json`) +2. SQL Server / EF Cache (if `IDistributedCache` is registered and no Redis) +3. Memory-only (default fallback) + +> **Note**: ExtendedCache works seamlessly with any `IDistributedCache` backend. In self-hosted scenarios without Redis, you can configure ExtendedCache to use SQL Server or Entity Framework cache as its distributed layer. This provides local memory caching in front of the database cache, with the option to add Redis later if needed. You won't get the backplane (cross-instance invalidation) without Redis, but you still get stampede protection, eager refresh, and fail-safe mode. + +### Specific Example: Organization/Provider Abilities + +Organization and provider abilities are read extremely frequently (on every request that checks permissions) but change infrequently. `ExtendedCache` is ideal for this access pattern with its eager refresh and Redis backplane support: + +```csharp +services.AddExtendedCache("OrganizationAbilities", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromMinutes(10), + EagerRefreshThreshold = 0.9, // Refresh at 90% of TTL + IsFailSafeEnabled = true, + FailSafeMaxDuration = TimeSpan.FromHours(1) // Serve stale data up to 1 hour on backend failures +}); + +public class OrganizationAbilityService +{ + private readonly IFusionCache _cache; + private readonly IOrganizationRepository _organizationRepository; + + public OrganizationAbilityService( + [FromKeyedServices("OrganizationAbilities")] IFusionCache cache, + IOrganizationRepository organizationRepository) + { + _cache = cache; + _organizationRepository = organizationRepository; + } + + public async Task> GetOrganizationAbilitiesAsync() + { + return await _cache.GetOrSetAsync>( + "all-org-abilities", + async _ => + { + var abilities = await _organizationRepository.GetManyAbilitiesAsync(); + return abilities.ToDictionary(a => a.Id); + } + ); + } + + public async Task GetOrganizationAbilityAsync(Guid orgId) + { + var abilities = await GetOrganizationAbilitiesAsync(); + abilities.TryGetValue(orgId, out var ability); + return ability; + } + + public async Task UpsertOrganizationAbilityAsync(Organization organization) + { + // Update database + await _organizationRepository.ReplaceAsync(organization); + + // Invalidate cache - with Redis backplane, this broadcasts to all instances + await _cache.RemoveAsync("all-org-abilities"); + } +} +``` + +**Why `ExtendedCache` for org/provider abilities:** + +- **High-frequency reads**: Every permission check reads abilities +- **Infrequent writes**: Abilities change rarely +- **Eager refresh**: Automatically refreshes at 90% of TTL to prevent cache misses +- **Fail-safe mode**: Serves stale data if database temporarily unavailable +- **Redis backplane**: Automatically invalidates across all instances when abilities change +- **No Service Bus dependency**: Simpler infrastructure (one Redis instead of Redis + Service Bus) + +### When NOT to Use + +- **Long-term persistent data** (days/weeks) - Use `IDistributedCache` with persistent keyed service for structured TTL support +- **Custom caching logic** - If ExtendedCache's API doesn't fit your use case, consider specialized in-memory cache + +--- + +## `IDistributedCache` + +`IDistributedCache` provides two service registrations for different use cases: + +1. **Default (unnamed) service** - For ephemeral, short-lived data +2. **Persistent cache** (keyed service: `"persistent"`) - For longer-lived data with structured TTL + +### When to Use + +**Default `IDistributedCache`**: + +- **Legacy code** already using `IDistributedCache` (consider migrating to `ExtendedCache`) +- **Third-party integrations** requiring `IDistributedCache` interface +- **ASP.NET Core session storage** (framework dependency) +- You have **specific requirements** that ExtendedCache doesn't support + +> **Note**: For new code, prefer `ExtendedCache` over default `IDistributedCache`. ExtendedCache can be configured with `EnableDistributedCache = false` to use memory-only caching with the same performance as raw `IMemoryCache`, while still providing stampede protection, fail-safe, and eager refresh. + +**Persistent cache** (keyed service: `"persistent"`): + +- **Critical data where memory loss would impact users** (refresh tokens, consent grants) +- **Long-lived structured data** with automatic TTL (days to weeks) +- **Long-lived OAuth/OIDC grants** that must survive application restarts +- **Payment intents** or workflow state that spans multiple requests +- Data requiring **automatic expiration** without manual cleanup +- **Large cache datasets** that benefit from external storage (e.g., thousands of refresh tokens) + +### Pros + +✅ **Standard ASP.NET Core interface**: Widely understood, well-documented + +✅ **Multiple backend support**: Redis, SQL Server, Entity Framework, Cosmos DB + +✅ **Automatic backend selection**: Picks the right backend based on configuration + +✅ **Simple API**: Just `Get`, `Set`, `Remove`, `Refresh` + +✅ **Minimal overhead**: No additional layers beyond the backend + +✅ **Keyed services**: Separate configurations for different use cases + +### Cons + +❌ **No stampede protection**: Multiple requests = multiple backend calls + +❌ **No fail-safe mode**: Backend unavailable = cache miss + +❌ **No backplane**: Manual cache invalidation across instances + +❌ **Manual serialization**: You handle JSON serialization (or use helpers) + +❌ **Manual TTL management** (default service): Must track expiration manually + +### Example Usage: Default (Ephemeral Data) + +#### 1. Registration (already done in Api, Admin, Billing, Events, EventsProcessor, Identity, and Notifications Startup.cs files): + +```csharp +services.AddDistributedCache(globalSettings); +``` + +#### 2. Inject and use for short-lived tokens: + +```csharp +public class TwoFactorService +{ + private readonly IDistributedCache _cache; + + public TwoFactorService(IDistributedCache cache) + { + _cache = cache; + } + + public async Task GetEmailTokenAsync(Guid userId) + { + var key = $"email-2fa:{userId}"; + var cached = await _cache.GetStringAsync(key); + return cached; + } + + public async Task SetEmailTokenAsync(Guid userId, string token) + { + var key = $"email-2fa:{userId}"; + await _cache.SetStringAsync(key, token, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5) + }); + } +} +``` + +#### 3. Using JSON helpers: + +```csharp +using Bit.Core.Utilities; + +public async Task GetDataAsync(string key) +{ + return await _cache.TryGetValue(key); +} + +public async Task SetDataAsync(string key, MyData data) +{ + await _cache.SetAsync(key, data, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30) + }); +} +``` + +### Example Usage: Persistent (Long-Lived Data) + +The persistent cache is accessed via keyed service injection and is optimized for long-lived structured data with automatic TTL support. + +#### Specific Example: Payment Workflow State + +The persistent `IDistributedCache` service is appropriate for workflow state that spans multiple requests and needs automatic TTL cleanup. + +```csharp +public class SetupIntentDistributedCache( + [FromKeyedServices("persistent")] IDistributedCache distributedCache) : ISetupIntentCache +{ + public async Task Set(Guid subscriberId, string setupIntentId) + { + // Bidirectional mapping for payment flow + var bySubscriberIdCacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + var bySetupIntentIdCacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}"; + + // Note: No explicit TTL set here. Cosmos DB uses container-level TTL for automatic cleanup. + // In cloud, Cosmos TTL handles expiration. In self-hosted, the cache backend manages TTL. + await Task.WhenAll( + distributedCache.SetStringAsync(bySubscriberIdCacheKey, setupIntentId), + distributedCache.SetStringAsync(bySetupIntentIdCacheKey, subscriberId.ToString())); + } + + public async Task GetSetupIntentIdForSubscriber(Guid subscriberId) + { + var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + return await distributedCache.GetStringAsync(cacheKey); + } + + public async Task GetSubscriberIdForSetupIntent(string setupIntentId) + { + var cacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}"; + var value = await distributedCache.GetStringAsync(cacheKey); + if (string.IsNullOrEmpty(value) || !Guid.TryParse(value, out var subscriberId)) + { + return null; + } + return subscriberId; + } + + public async Task RemoveSetupIntentForSubscriber(Guid subscriberId) + { + var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + await distributedCache.RemoveAsync(cacheKey); + } +} +``` + +#### Specific Example: Long-Lived OAuth Grants + +Long-lived OAuth grants (refresh tokens, consent grants, device codes) use the persistent `IDistributedCache` in **cloud** and `IGrantRepository` as a **database fallback for self-hosted** when persistent cache is not configured: + +**Cloud (Bitwarden-hosted)**: + +- Uses persistent `IDistributedCache` directly (backed by Cosmos DB) +- Automatic TTL via Cosmos DB container-level TTL + +**Self-hosted**: + +- Uses `IGrantRepository` as a database fallback when persistent cache backend is not available +- Stores grants in `Grant` database table with automatic expiration + +**Grant type recommendations:** + +| Grant Type | Lifetime | Durability Requirement | Recommended Storage | Rationale | +| ------------------------ | ------------ | ---------------------- | ------------------- | ------------------------------------------------------------------------------------------- | +| SSO authorization codes | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | User can re-initiate SSO flow if code is lost; short lifetime limits exposure window | +| OIDC authorization codes | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | OAuth spec allows user to retry authorization; code is single-use and short-lived | +| PKCE code verifiers | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | Tied to authorization code lifecycle; can be regenerated if authorization is retried | +| Refresh tokens | Days-weeks | Must persist | Persistent cache | Losing these forces user re-authentication; critical for seamless user experience | +| Consent grants | Weeks-months | Must persist | Persistent cache | User shouldn't have to re-consent frequently; loss degrades UX and trust | +| Device codes | Days | Must persist | Persistent cache | Device flow is async; losing codes breaks pending device authorizations with no recovery UX | + +### Backend Configuration + +The backend is automatically selected based on configuration and service key: + +#### Default `IDistributedCache` (ephemeral) + +**Cloud (Bitwarden-hosted)**: + +- **Redis** only (always configured in cloud environments) + +**Self-hosted priority order**: + +1. **Redis** (if `GlobalSettings.DistributedCache.Redis.ConnectionString` is configured) +2. **SQL Server Cache table** (if database provider is SQL Server) +3. **Entity Framework Cache table** (for PostgreSQL, MySQL, SQLite) + +#### Persistent cache (keyed service: `"persistent"`) + +**Cloud (Bitwarden-hosted)**: + +1. **Cosmos DB** (if `GlobalSettings.DistributedCache.Cosmos.ConnectionString` is configured) + - Database: `cache` + - Container: `default` +2. **Falls back to Redis** + +**Self-hosted priority order**: + +1. **Redis** (if configured) +2. **SQL Server Cache table** (if database provider is SQL Server) +3. **Entity Framework Cache table** (for PostgreSQL, MySQL, SQLite) + +### Backend Details + +#### Redis + +```csharp +services.AddStackExchangeRedisCache(options => +{ + options.Configuration = globalSettings.DistributedCache.Redis.ConnectionString; +}); +``` + +**Used for**: Cloud (always), self-hosted (if configured) + +- **Pros**: Fast, horizontally scalable, battle-tested +- **Cons**: Additional infrastructure dependency (self-hosted only) +- **TTL**: Via `AbsoluteExpiration` in cache entry options + +#### SQL Server Cache Table (Self-hosted only) + +```csharp +services.AddDistributedSqlServerCache(options => +{ + options.ConnectionString = globalSettings.SqlServer.ConnectionString; + options.SchemaName = "dbo"; + options.TableName = "Cache"; +}); +``` + +**Used for**: Self-hosted deployments without Redis + +- **Pros**: No additional infrastructure, works with existing database +- **Cons**: Slower than Redis, adds load to database, less scalable +- **TTL**: Via `ExpiresAtTime` and `AbsoluteExpiration` columns + +#### Entity Framework Cache (Self-hosted only) + +```csharp +services.AddSingleton(); +``` + +**Used for**: Self-hosted deployments with PostgreSQL, MySQL, or SQLite + +- **Pros**: Works with any EF-supported database (PostgreSQL, MySQL, SQLite) +- **Cons**: Slower than Redis, requires periodic expiration scanning, adds DB load + +**Features**: + +- Thread-safe operations with mutex locks +- Automatic expiration scanning every 30 minutes +- Sliding and absolute expiration support +- Provider-specific duplicate key handling + +**TTL**: Via `ExpiresAtTime` and `AbsoluteExpiration` columns with background scanning + +#### Cosmos DB (Cloud only, persistent cache) + +```csharp +services.AddKeyedSingleton("persistent", (provider, _) => +{ + return new CosmosCache(new CosmosCacheOptions + { + DatabaseName = "cache", + ContainerName = "default", + ClientBuilder = cosmosClientBuilder + }); +}); +``` + +**Used for**: Cloud persistent keyed service only + +- **Pros**: Globally distributed, automatic TTL support via container-level TTL, optimized for long-lived data +- **Cons**: Cloud-only, higher latency than Redis + +**TTL**: Cosmos DB container-level TTL (automatic cleanup, no scanning required) + +### Comparison: Default vs Persistent + +| Characteristic | Default | Persistent cache (`"persistent"`) | +| ----------------------- | ------------------------------ | ---------------------------------------------- | +| **Primary Use Case** | Ephemeral tokens, session data | Long-lived grants, workflow state | +| **Typical TTL** | 5-15 minutes | Hours to weeks | +| **User Impact if Lost** | Low (user can retry) | High (forces re-auth, interrupts workflows) | +| **Scale Consideration** | Small datasets | Large/growing datasets (thousands to millions) | +| **Cloud Backend** | Redis | Cosmos DB → Redis | +| **Self-Hosted Backend** | Redis → SQL → EF | Redis → SQL → EF | +| **Automatic Cleanup** | Manual expiration | Automatic TTL (Cosmos) | +| **Data Structure** | Simple key-value | Supports structured data | +| **Example** | 2FA codes, TOTP tokens | Refresh tokens, payment intents | + +### Choosing Default vs Persistent + +**Use Default when**: + +- Data lifetime < 15 minutes +- Ephemeral authentication tokens +- Simple key-value pairs +- Cost optimization is important +- Data loss on restart is acceptable + +**Use Persistent when**: + +- **Data loss would have user impact** (e.g., losing refresh tokens forces re-authentication) +- Data lifetime > 15 minutes +- **Cache size is large or growing** (thousands of items that exceed memory constraints) +- Structured data with relationships +- Automatic TTL cleanup is required +- Data must survive restarts and deployments +- Query capabilities are needed (via Cosmos DB) + +### When NOT to Use + +- **New general-purpose caching** - Use `ExtendedCache` instead for stampede protection, fail-safe, and backplane support +- **Organization/Provider abilities** - Use `ExtendedCache` with preloading pattern (see example above) +- **Short-lived ephemeral data** without persistence requirements - Use `ExtendedCache` (simpler, more features) + +--- + +## `IApplicationCacheService` (Deprecated) + +> **⚠️ Deprecated**: This service is being phased out in favor of `ExtendedCache`. New code should use `ExtendedCache` with the preloading pattern shown in the [Organization/Provider Abilities example](#specific-example-organizationprovider-abilities) above. + +### Background + +`IApplicationCacheService` was a **highly domain-specific caching service** built for Bitwarden organization and provider abilities. It used in-memory cache with Azure Service Bus for cross-instance invalidation. + +**Why it's being replaced:** + +- **Infrastructure complexity**: Required both Redis and Azure Service Bus +- **Limited applicability**: Only worked for org/provider abilities +- **Maintenance burden**: Custom implementation instead of leveraging standard caching primitives +- **Better alternative exists**: `ExtendedCache` with Redis backplane provides the same functionality with simpler infrastructure + +### Migration Path + +**Old approach** (IApplicationCacheService): + +- In-memory cache with periodic refresh +- Azure Service Bus for cross-instance invalidation +- Custom implementation for each domain + +**New approach** (ExtendedCache): + +- Memory + Redis distributed cache with backplane +- Eager refresh for automatic background updates +- Fail-safe mode for resilience +- Standard FusionCache API +- One Redis instance instead of Redis + Service Bus + +See the [Organization/Provider Abilities example](#specific-example-organizationprovider-abilities) for the recommended migration pattern. + +### When NOT to Use + +❌ **Do not use for new code** - Use `ExtendedCache` instead + +For existing code using `IApplicationCacheService`, plan migration to `ExtendedCache` using the pattern shown above. + +--- + +## Specialized In-Memory Cache + +> **Recommendation**: In most cases, use `ExtendedCache` with `EnableDistributedCache = false` instead of implementing a specialized in-memory cache. ExtendedCache provides the same memory-only performance with built-in stampede protection, eager refresh, and fail-safe capabilities. + +### When to Use + +Use a specialized in-memory cache only when: + +- **ExtendedCache's API doesn't fit** your specific use case +- **Custom eviction logic** is required beyond TTL-based expiration +- **Non-standard data structures** (e.g., priority queues, LRU with custom scoring) +- **Direct memory access patterns** that bypass serialization entirely + +For general high-performance caching, prefer `ExtendedCache` with memory-only mode. + +### Pros + +✅ **Maximum performance**: No serialization, no network calls, no locking overhead + +✅ **Simple implementation**: Just a `Dictionary` or `ConcurrentDictionary` + +✅ **Zero infrastructure**: No Redis, no database, no additional dependencies + +### Cons + +❌ **No horizontal scaling**: Each instance has separate cache state + +❌ **Manual invalidation**: No built-in cache invalidation mechanism + +❌ **Manual TTL**: You implement expiration logic + +❌ **Memory pressure**: Large datasets can cause GC issues + +### Example Implementation + +#### Simple in-memory cache: + +```csharp +public class MyFeatureCache +{ + private readonly ConcurrentDictionary> _cache = new(); + private readonly TimeSpan _defaultExpiration = TimeSpan.FromMinutes(30); + + public MyData GetOrAdd(string key, Func factory) + { + var entry = _cache.GetOrAdd(key, _ => new CacheEntry + { + Value = factory(), + ExpiresAt = DateTime.UtcNow + _defaultExpiration + }); + + // WARNING: This implementation has a race condition. Multiple threads detecting + // expiration simultaneously may each call TryRemove and then recursively call + // GetOrAdd, potentially causing the factory to execute multiple times. For + // production use cases requiring thread-safe expiration, consider using + // IMemoryCache with GetOrCreateAsync or ExtendedCache with stampede protection. + if (entry.ExpiresAt < DateTime.UtcNow) + { + _cache.TryRemove(key, out _); + return GetOrAdd(key, factory); + } + + return entry.Value; + } + + private class CacheEntry + { + public T Value { get; set; } + public DateTime ExpiresAt { get; set; } + } +} +``` + +#### Using `IMemoryCache`: + +```csharp +public class MyService +{ + private readonly IMemoryCache _memoryCache; + + public MyService(IMemoryCache memoryCache) + { + _memoryCache = memoryCache; + } + + public async Task GetDataAsync(string key) + { + return await _memoryCache.GetOrCreateAsync(key, async entry => + { + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30); + entry.SetPriority(CacheItemPriority.High); + + return await _repository.GetDataAsync(key); + }); + } +} +``` + +### When NOT to Use + +- **Most general-purpose caching** - Use `ExtendedCache` with memory-only mode instead +- **Data requiring stampede protection** - Use `ExtendedCache` +- **Multi-instance deployments** requiring consistency - Use `ExtendedCache` with Redis +- **Long-lived OAuth grants** - Use persistent `IDistributedCache` + +> **Important**: Before implementing a custom in-memory cache, first try `ExtendedCache` with `EnableDistributedCache = false`. This gives you memory-only performance with automatic stampede protection, eager refresh, and fail-safe mode. + +--- + +## Backend Configuration + +### Configuration Priority + +The following table shows how different caching options resolve to storage backends based on configuration: + +| Cache Option | Cloud Backend | Self-Hosted Backend | Config Setting | +| -------------------------------------- | ------------------------- | --------------------------- | --------------------------------------------------------- | +| **ExtendedCache** | Redis → Memory | Redis → Memory | `GlobalSettings.DistributedCache.Redis.ConnectionString` | +| **IDistributedCache** (default) | Redis | Redis → SQL → EF | `GlobalSettings.DistributedCache.Redis.ConnectionString` | +| **IDistributedCache** (`"persistent"`) | Cosmos → Redis | Redis → SQL → EF | `GlobalSettings.DistributedCache.Cosmos.ConnectionString` | +| **OAuth Grants** (long-lived) | Persistent cache (Cosmos) | `IGrantRepository` (SQL/EF) | Various (see above) | + +### Redis Configuration + +**Cloud (Bitwarden-hosted)**: + +```json +{ + "GlobalSettings": { + "DistributedCache": { + "Redis": { + "ConnectionString": "redis.example.com:6379,ssl=true,password=..." + } + } + } +} +``` + +**Self-hosted** (`appsettings.json`): + +```json +{ + "globalSettings": { + "distributedCache": { + "redis": { + "connectionString": "localhost:6379" + } + } + } +} +``` + +### Cosmos DB Configuration + +**Persistent `IDistributedCache`** (cloud only): + +```json +{ + "GlobalSettings": { + "DistributedCache": { + "Cosmos": { + "ConnectionString": "AccountEndpoint=https://...;AccountKey=..." + } + } + } +} +``` + +- Database: `cache` +- Container: `default` +- Used for long-lived grants in cloud deployments + +### SQL Server Cache + +**Automatic configuration** (if SQL Server is database provider): + +```json +{ + "globalSettings": { + "sqlServer": { + "connectionString": "Server=...;Database=...;User Id=...;Password=..." + } + } +} +``` + +- Schema: `dbo` +- Table: `Cache` +- Migrations: Applied automatically + +### Entity Framework Cache + +**Automatic fallback** for PostgreSQL, MySQL, SQLite: + +No additional configuration required. Uses existing database connection. + +- Table: `Cache` +- Migrations: Applied automatically + +--- + +## Performance Considerations + +### Performance Characteristics + +| Backend | Read Latency | Write Latency | Throughput | +| -------------------- | ------------ | ------------- | ------------- | +| **Memory** | <1ms | <1ms | >100K req/s | +| **Redis** | 1-5ms | 1-5ms | 10K-50K req/s | +| **SQL Server** | 5-20ms | 10-50ms | 1K-5K req/s | +| **Entity Framework** | 5-20ms | 10-50ms | 1K-5K req/s | +| **Cosmos DB** | 5-15ms | 5-15ms | 10K+ req/s | + +**Note**: Latencies represent typical p95 values in production environments. Redis latencies assume same-datacenter deployment and include serialization overhead. Actual performance varies based on network topology, data size, and load. + +### Recommendations + +**For high-frequency reads (>1K req/s)**: + +1. `ExtendedCache` with Redis (cloud) +2. `ExtendedCache` memory-only (self-hosted, single instance) +3. Specialized in-memory cache (extreme performance requirements) + +**For moderate traffic (100-1K req/s)**: + +1. `ExtendedCache` with shared Redis +2. `IDistributedCache` with SQL Server cache + +**For low traffic (<100 req/s)**: + +1. `IDistributedCache` with SQL Server / EF cache +2. `ExtendedCache` memory-only + +--- + +## Testing Caches + +### Unit Testing + +**`ExtendedCache`**: + +```csharp +[Fact] +public async Task TestCacheHit() +{ + var services = new ServiceCollection(); + services.AddMemoryCache(); + services.AddExtendedCache("TestCache", new GlobalSettings + { + DistributedCache = new GlobalSettings.DistributedCacheSettings() + }); + + var provider = services.BuildServiceProvider(); + var cache = provider.GetRequiredKeyedService("TestCache"); + + await cache.SetAsync("key", "value"); + var result = await cache.GetOrDefaultAsync("key"); + + Assert.Equal("value", result); +} +``` + +**`IDistributedCache`**: + +```csharp +[Fact] +public async Task TestDistributedCache() +{ + var cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + + await cache.SetStringAsync("key", "value"); + var result = await cache.GetStringAsync("key"); + + Assert.Equal("value", result); +} +``` + +### Integration Testing + +**Example**: + +```csharp +[DatabaseTheory, DatabaseData] +public async Task Cache_ExpirationScanning_RemovesExpiredItems(IDistributedCache cache) +{ + // Set item with 1-second expiration + await cache.SetAsync("key", Encoding.UTF8.GetBytes("value"), new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromSeconds(1) + }); + + // Wait for expiration + await Task.Delay(TimeSpan.FromSeconds(2)); + + // Trigger expiration scan + var entityCache = cache as EntityFrameworkCache; + await entityCache.ScanForExpiredItemsAsync(); + + // Verify item is removed + var result = await cache.GetAsync("key"); + Assert.Null(result); +} +``` + +--- + +## Migration Examples + +Examples of migrating from one caching option to another: + +### From `IDistributedCache` → `ExtendedCache` + +**Before**: + +```csharp +// Registration +services.AddDistributedCache(globalSettings); + +// Constructor +public MyService(IDistributedCache cache, IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + var data = await _cache.TryGetValue(key); + if (data == null) + { + data = await _repository.GetAsync(key); + await _cache.SetAsync(key, data, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30) + }); + } + return data; +} +``` + +**After**: + +```csharp +// Registration +services.AddDistributedCache(globalSettings); +services.AddExtendedCache("MyFeature", globalSettings); + +// Constructor +public MyService( + [FromKeyedServices("MyFeature")] IFusionCache cache, + IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + return await _cache.GetOrSetAsync( + key, + async _ => await _repository.GetAsync(key), + options => options.SetDuration(TimeSpan.FromMinutes(30)) + ); +} +``` + +### From In-Memory → `ExtendedCache` + +**Before**: + +```csharp +// Field +private readonly ConcurrentDictionary _cache = new(); +private readonly IRepository _repository; + +// Constructor +public MyService(IRepository repository) +{ + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + if (_cache.TryGetValue(key, out var cached)) + { + return cached; + } + + var data = await _repository.GetAsync(key); + _cache.TryAdd(key, data); + return data; +} +``` + +**After**: + +```csharp +// Registration +services.AddExtendedCache("MyFeature", globalSettings); + +// Constructor +public MyService( + [FromKeyedServices("MyFeature")] IFusionCache cache, + IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + return await _cache.GetOrSetAsync( + key, + async _ => await _repository.GetAsync(key) + ); +} +``` diff --git a/src/Core/Utilities/EmailValidation.cs b/src/Core/Utilities/EmailValidation.cs index f6832945af..10892f85c4 100644 --- a/src/Core/Utilities/EmailValidation.cs +++ b/src/Core/Utilities/EmailValidation.cs @@ -1,4 +1,6 @@ -using System.Text.RegularExpressions; +using System.Net.Mail; +using System.Text.RegularExpressions; +using Bit.Core.Exceptions; using MimeKit; namespace Bit.Core.Utilities; @@ -41,4 +43,22 @@ public static class EmailValidation return true; } + + /// + /// Extracts the domain portion from an email address and normalizes it to lowercase. + /// + /// The email address to extract the domain from. + /// The domain portion of the email address in lowercase (e.g., "example.com"). + /// Thrown when the email address format is invalid. + public static string GetDomain(string email) + { + try + { + return new MailAddress(email).Host.ToLower(); + } + catch (Exception ex) when (ex is FormatException || ex is ArgumentException) + { + throw new BadRequestException("Invalid email address format."); + } + } } diff --git a/src/Core/Utilities/EventIntegrationsCacheConstants.cs b/src/Core/Utilities/EventIntegrationsCacheConstants.cs new file mode 100644 index 0000000000..19cc3f949c --- /dev/null +++ b/src/Core/Utilities/EventIntegrationsCacheConstants.cs @@ -0,0 +1,84 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; + +namespace Bit.Core.Utilities; + +/// +/// Provides cache key generation helpers and cache name constants for event integration–related entities. +/// +public static class EventIntegrationsCacheConstants +{ + /// + /// The base cache name used for storing event integration data. + /// + public const string CacheName = "EventIntegrations"; + + /// + /// Duration TimeSpan for adding OrganizationIntegrationConfigurationDetails to the cache. + /// + public static readonly TimeSpan DurationForOrganizationIntegrationConfigurationDetails = TimeSpan.FromDays(1); + + /// + /// Builds a deterministic cache key for a . + /// + /// The unique identifier of the group. + /// + /// A cache key for this Group. + /// + public static string BuildCacheKeyForGroup(Guid groupId) => + $"Group:{groupId:N}"; + + /// + /// Builds a deterministic cache key for an . + /// + /// The unique identifier of the organization. + /// + /// A cache key for the Organization. + /// + public static string BuildCacheKeyForOrganization(Guid organizationId) => + $"Organization:{organizationId:N}"; + + /// + /// Builds a deterministic cache key for an organization user . + /// + /// The unique identifier of the organization to which the user belongs. + /// The unique identifier of the user. + /// + /// A cache key for the user. + /// + public static string BuildCacheKeyForOrganizationUser(Guid organizationId, Guid userId) => + $"OrganizationUserUserDetails:{organizationId:N}:{userId:N}"; + + /// + /// Builds a deterministic cache key for an organization's integration configuration details + /// . + /// + /// The unique identifier of the organization. + /// The of the integration. + /// The specific of the event configured. + /// + /// A cache key for the configuration details. + /// + public static string BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + Guid organizationId, + IntegrationType integrationType, + EventType eventType + ) => $"OrganizationIntegrationConfigurationDetails:{organizationId:N}:{integrationType}:{eventType}"; + + /// + /// Builds a deterministic tag for tagging an organization's integration configuration details. This tag is then + /// used to tag all of the that result from this + /// integration, which allows us to remove all relevant entries when an integration is changed or removed. + /// + /// The unique identifier of the organization to which the user belongs. + /// The of the integration. + /// + /// A cache tag to use for the configuration details. + /// + public static string BuildCacheTagForOrganizationIntegration( + Guid organizationId, + IntegrationType integrationType + ) => $"OrganizationIntegration:{organizationId:N}:{integrationType}"; +} diff --git a/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs b/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs new file mode 100644 index 0000000000..f287f64e54 --- /dev/null +++ b/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs @@ -0,0 +1,186 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.StackExchangeRedis; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; +using ZiggyCreatures.Caching.Fusion; +using ZiggyCreatures.Caching.Fusion.Backplane; +using ZiggyCreatures.Caching.Fusion.Backplane.StackExchangeRedis; +using ZiggyCreatures.Caching.Fusion.Serialization.SystemTextJson; + +namespace Microsoft.Extensions.DependencyInjection; + +public static class ExtendedCacheServiceCollectionExtensions +{ + /// + /// Adds a new, named Fusion Cache to the service + /// collection. If an existing cache of the same name is found, it will do nothing.
    + ///
    + /// Note: When re-using an existing distributed cache, it is expected to call this method after calling + /// services.AddDistributedCache(globalSettings)
    This ensures that DI correctly finds + /// and re-uses the shared distributed cache infrastructure.
    + ///
    + /// Backplane: Cross-instance cache invalidation is only available when using Redis. + /// Non-Redis distributed caches operate with eventual consistency across multiple instances. + ///
    + public static IServiceCollection AddExtendedCache( + this IServiceCollection services, + string cacheName, + GlobalSettings globalSettings, + GlobalSettings.ExtendedCacheSettings? settings = null) + { + settings ??= globalSettings.DistributedCache.DefaultExtendedCache; + if (settings is null || string.IsNullOrEmpty(cacheName)) + { + return services; + } + + // If a cache already exists with this key, do nothing + if (services.Any(s => s.ServiceType == typeof(IFusionCache) && + s.ServiceKey?.Equals(cacheName) == true)) + { + return services; + } + + if (services.All(s => s.ServiceType != typeof(FusionCacheSystemTextJsonSerializer))) + { + services.AddFusionCacheSystemTextJsonSerializer(); + } + var fusionCacheBuilder = services + .AddFusionCache(cacheName) + .WithCacheKeyPrefix($"{cacheName}:") + .AsKeyedServiceByCacheName() + .WithOptions(opt => + { + opt.DistributedCacheCircuitBreakerDuration = settings.DistributedCacheCircuitBreakerDuration; + }) + .WithDefaultEntryOptions(new FusionCacheEntryOptions + { + Duration = settings.Duration, + IsFailSafeEnabled = settings.IsFailSafeEnabled, + FailSafeMaxDuration = settings.FailSafeMaxDuration, + FailSafeThrottleDuration = settings.FailSafeThrottleDuration, + EagerRefreshThreshold = settings.EagerRefreshThreshold, + FactorySoftTimeout = settings.FactorySoftTimeout, + FactoryHardTimeout = settings.FactoryHardTimeout, + DistributedCacheSoftTimeout = settings.DistributedCacheSoftTimeout, + DistributedCacheHardTimeout = settings.DistributedCacheHardTimeout, + AllowBackgroundDistributedCacheOperations = settings.AllowBackgroundDistributedCacheOperations, + JitterMaxDuration = settings.JitterMaxDuration + }) + .WithRegisteredSerializer(); + + if (!settings.EnableDistributedCache) + return services; + + if (settings.UseSharedDistributedCache) + { + if (!CoreHelpers.SettingHasValue(globalSettings.DistributedCache.Redis.ConnectionString)) + { + // Using Shared Non-Redis Distributed Cache: + // 1. Assume IDistributedCache is already registered (e.g., Cosmos, SQL Server) + // 2. Backplane not supported (Redis-only feature, requires pub/sub) + + fusionCacheBuilder + .TryWithRegisteredDistributedCache(); + + return services; + } + + // Using Shared Redis, TryAdd and reuse all pieces (multiplexer, distributed cache and backplane) + + services.TryAddSingleton(sp => + CreateConnectionMultiplexer(sp, cacheName, globalSettings.DistributedCache.Redis.ConnectionString)); + + services.TryAddSingleton(sp => + { + var mux = sp.GetRequiredService(); + return new RedisCache(new RedisCacheOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + }); + + services.TryAddSingleton(sp => + { + var mux = sp.GetRequiredService(); + return new RedisBackplane(new RedisBackplaneOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + }); + + fusionCacheBuilder + .WithRegisteredDistributedCache() + .WithRegisteredBackplane(); + + return services; + } + + // Using keyed Distributed Cache. Create/Reuse all pieces as keyed services. + + if (!CoreHelpers.SettingHasValue(settings.Redis.ConnectionString)) + { + // Using Keyed Non-Redis Distributed Cache: + // 1. Assume IDistributedCache (e.g., Cosmos, SQL Server) is already registered with cacheName as key + // 2. Backplane not supported (Redis-only feature, requires pub/sub) + + fusionCacheBuilder + .TryWithRegisteredKeyedDistributedCache(serviceKey: cacheName); + + return services; + } + + // Using Keyed Redis: TryAdd and reuse all pieces (multiplexer, distributed cache and backplane) + + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => CreateConnectionMultiplexer(sp, cacheName, settings.Redis.ConnectionString) + ); + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => + { + var mux = sp.GetRequiredKeyedService(cacheName); + return new RedisCache(new RedisCacheOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + } + ); + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => + { + var mux = sp.GetRequiredKeyedService(cacheName); + return new RedisBackplane(new RedisBackplaneOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + } + ); + + fusionCacheBuilder + .WithRegisteredKeyedDistributedCacheByCacheName() + .WithRegisteredKeyedBackplaneByCacheName(); + + return services; + } + + private static ConnectionMultiplexer CreateConnectionMultiplexer(IServiceProvider sp, string cacheName, + string connectionString) + { + try + { + return ConnectionMultiplexer.Connect(connectionString); + } + catch (Exception ex) + { + var logger = sp.GetService(); + logger?.LogError(ex, "Failed to connect to Redis for cache {CacheName}", cacheName); + throw; + } + } +} diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index 54bd84df6f..b950e30d5d 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -1,165 +1,78 @@ -using System.Security.Cryptography.X509Certificates; -using Bit.Core.Settings; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -using Serilog; -using Serilog.Events; -using Serilog.Sinks.Syslog; namespace Bit.Core.Utilities; public static class LoggerFactoryExtensions { - public static void UseSerilog( - this IApplicationBuilder appBuilder, - IWebHostEnvironment env, - IHostApplicationLifetime applicationLifetime, - GlobalSettings globalSettings) + /// + /// + /// + /// + /// + public static IHostBuilder AddSerilogFileLogging(this IHostBuilder hostBuilder) { - if (env.IsDevelopment() && !globalSettings.EnableDevLogging) + return hostBuilder.ConfigureLogging((context, logging) => { - return; - } - - applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); - } - - public static ILoggingBuilder AddSerilog( - this ILoggingBuilder builder, - WebHostBuilderContext context, - Func? filter = null) - { - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); - - if (context.HostingEnvironment.IsDevelopment() && !globalSettings.EnableDevLogging) - { - return builder; - } - - bool inclusionPredicate(LogEvent e) - { - if (filter == null) + if (context.HostingEnvironment.IsDevelopment()) { - return true; + return; } - var eventId = e.Properties.TryGetValue("EventId", out var eventIdValue) ? eventIdValue.ToString() : null; - if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) + + // If they have begun using the new settings location, use that + if (!string.IsNullOrEmpty(context.Configuration["Logging:PathFormat"])) { - return true; - } - return filter(e, globalSettings); - } - - var logSentryWarning = false; - var logSyslogWarning = false; - - // Path format is the only required option for file logging, we will use that as - // the keystone for if they have configured the new location. - var newPathFormat = context.Configuration["Logging:PathFormat"]; - - var config = new LoggerConfiguration() - .MinimumLevel.Verbose() - .Enrich.FromLogContext() - .Filter.ByIncludingOnly(inclusionPredicate); - - if (CoreHelpers.SettingHasValue(globalSettings.Sentry.Dsn)) - { - config.WriteTo.Sentry(globalSettings.Sentry.Dsn) - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Syslog.Destination)) - { - logSyslogWarning = true; - // appending sitename to project name to allow easier identification in syslog. - var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; - if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) - { - config.WriteTo.LocalSyslog(appName); - } - else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) - { - // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. - int port = syslogAddress.Port >= 0 - ? syslogAddress.Port - : 514; - - if (syslogAddress.Scheme.Equals("udp")) - { - config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tcp")) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tls")) - { - if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - useTls: true, - certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, - globalSettings.Syslog.CertificateThumbprint)); - } - else - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - useTls: true, - certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, - globalSettings.Syslog?.CertificatePassword ?? string.Empty)); - } - } - } - } - else if (!string.IsNullOrEmpty(newPathFormat)) - { - // Use new location - builder.AddFile(context.Configuration.GetSection("Logging")); - } - else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) - { - if (globalSettings.LogRollBySizeLimit.HasValue) - { - var pathFormat = Path.Combine(globalSettings.LogDirectory, $"{globalSettings.ProjectName.ToLowerInvariant()}.log"); - if (globalSettings.LogDirectoryByProject) - { - pathFormat = Path.Combine(globalSettings.LogDirectory, globalSettings.ProjectName, "log.txt"); - } - config.WriteTo.File(pathFormat, rollOnFileSizeLimit: true, - fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); + logging.AddFile(context.Configuration.GetSection("Logging")); } else { - var pathFormat = Path.Combine(globalSettings.LogDirectory, $"{globalSettings.ProjectName.ToLowerInvariant()}_{{Date}}.log"); - if (globalSettings.LogDirectoryByProject) + var globalSettingsSection = context.Configuration.GetSection("GlobalSettings"); + var loggingOptions = new LegacyFileLoggingOptions(); + globalSettingsSection.Bind(loggingOptions); + + if (string.IsNullOrWhiteSpace(loggingOptions.LogDirectory)) { - pathFormat = Path.Combine(globalSettings.LogDirectory, globalSettings.ProjectName, "{Date}.txt"); + return; + } + + var projectName = loggingOptions.ProjectName + ?? context.HostingEnvironment.ApplicationName; + + if (loggingOptions.LogRollBySizeLimit.HasValue) + { + var pathFormat = loggingOptions.LogDirectoryByProject + ? Path.Combine(loggingOptions.LogDirectory, projectName, "log.txt") + : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}.log"); + + logging.AddFile( + pathFormat: pathFormat, + fileSizeLimitBytes: loggingOptions.LogRollBySizeLimit.Value + ); + } + else + { + var pathFormat = loggingOptions.LogDirectoryByProject + ? Path.Combine(loggingOptions.LogDirectory, projectName, "{Date}.txt") + : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}_{{Date}}.log"); + + logging.AddFile( + pathFormat: pathFormat + ); } - config.WriteTo.RollingFile(pathFormat); } - config - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } + }); + } - var serilog = config.CreateLogger(); - - if (logSentryWarning) - { - serilog.Warning("Sentry for logging has been deprecated. Read more: https://btwrdn.com/log-deprecation"); - } - - if (logSyslogWarning) - { - serilog.Warning("Syslog for logging has been deprecated. Read more: https://btwrdn.com/log-deprecation"); - } - - builder.AddSerilog(serilog); - - return builder; + /// + /// Our own proprietary options that we've always supported in `GlobalSettings` configuration section. + /// + private class LegacyFileLoggingOptions + { + public string? ProjectName { get; set; } + public string? LogDirectory { get; set; } = "/etc/bitwarden/logs"; + public bool LogDirectoryByProject { get; set; } = true; + public long? LogRollBySizeLimit { get; set; } } } diff --git a/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs b/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs new file mode 100644 index 0000000000..a8208844a8 --- /dev/null +++ b/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs @@ -0,0 +1,24 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.Extensions.Hosting; + +namespace Bit.Core.Utilities; + +/// +/// Authorization attribute that restricts controller/action access to Development and QA environments only. +/// Returns 404 Not Found in all other environments. +/// +public class RequireLowerEnvironmentAttribute() : TypeFilterAttribute(typeof(LowerEnvironmentFilter)) +{ + private class LowerEnvironmentFilter(IWebHostEnvironment environment) : IAuthorizationFilter + { + public void OnAuthorization(AuthorizationFilterContext context) + { + if (!environment.IsDevelopment() && !environment.IsEnvironment("QA")) + { + context.Result = new NotFoundResult(); + } + } + } +} diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 36c4a54ae4..f0fbd80c38 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -1,13 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using System.Collections.Immutable; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations.OrganizationUsers; -using Bit.Core.Models.StaticStore; namespace Bit.Core.Utilities; @@ -110,56 +104,7 @@ public static class StaticStore GlobalDomains.Add(GlobalEquivalentDomainsType.Atlassian, new List { "atlassian.com", "bitbucket.org", "trello.com", "statuspage.io", "atlassian.net", "jira.com" }); GlobalDomains.Add(GlobalEquivalentDomainsType.Pinterest, new List { "pinterest.com", "pinterest.com.au", "pinterest.cl", "pinterest.de", "pinterest.dk", "pinterest.es", "pinterest.fr", "pinterest.co.uk", "pinterest.jp", "pinterest.co.kr", "pinterest.nz", "pinterest.pt", "pinterest.se" }); #endregion - - Plans = new List - { - new EnterprisePlan(true), - new EnterprisePlan(false), - new TeamsStarterPlan(), - new TeamsPlan(true), - new TeamsPlan(false), - - new Enterprise2023Plan(true), - new Enterprise2023Plan(false), - new Enterprise2020Plan(true), - new Enterprise2020Plan(false), - new TeamsStarterPlan2023(), - new Teams2023Plan(true), - new Teams2023Plan(false), - new Teams2020Plan(true), - new Teams2020Plan(false), - new FamiliesPlan(), - new FreePlan(), - new CustomPlan(), - - new Enterprise2019Plan(true), - new Enterprise2019Plan(false), - new Teams2019Plan(true), - new Teams2019Plan(false), - new Families2019Plan(), - new Families2025Plan() - }.ToImmutableList(); } public static IDictionary> GlobalDomains { get; set; } - [Obsolete("Use PricingClient.ListPlans to retrieve all plans.")] - public static IEnumerable Plans { get; } - public static IEnumerable SponsoredPlans { get; set; } = new[] - { - new SponsoredPlan - { - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - SponsoredProductTierType = ProductTierType.Families, - SponsoringProductTierType = ProductTierType.Enterprise, - StripePlanId = "2021-family-for-enterprise-annually", - UsersCanSponsor = (OrganizationUserOrganizationDetails org) => - org.PlanType.GetProductTier() == ProductTierType.Enterprise, - } - }; - - [Obsolete("Use PricingClient.GetPlan to retrieve a plan.")] - public static Plan GetPlan(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType); - - public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => - SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); } diff --git a/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs b/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs index fbd553d772..bb3bafb230 100644 --- a/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs +++ b/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs @@ -14,7 +14,7 @@ public class NormalCipherPermissions throw new Exception("Cipher needs to belong to a user or an organization."); } - if (user.Id == cipherDetails.UserId) + if (cipherDetails.OrganizationId == null && user.Id == cipherDetails.UserId) { return true; } diff --git a/src/Core/Vault/Services/ICipherService.cs b/src/Core/Vault/Services/ICipherService.cs index 110d4b6ea4..765dae30c1 100644 --- a/src/Core/Vault/Services/ICipherService.cs +++ b/src/Core/Vault/Services/ICipherService.cs @@ -17,7 +17,7 @@ public interface ICipherService Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, - string attachmentId, Guid organizationShareId, DateTime? lastKnownRevisionDate = null); + string attachmentId, Guid organizationShareId); Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); @@ -34,7 +34,7 @@ public interface ICipherService Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false); Task> RestoreManyAsync(IEnumerable cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false); - Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId, DateTime? lastKnownRevisionDate = null); + Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); Task ValidateBulkCollectionAssignmentAsync(IEnumerable collectionIds, IEnumerable cipherIds, Guid userId); diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index 4e980f66b6..2085345b16 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -183,9 +183,8 @@ public class CipherService : ICipherService } } - public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment, DateTime? lastKnownRevisionDate = null) + public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) { - ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (attachment == null) { throw new BadRequestException("Cipher attachment does not exist"); @@ -290,11 +289,10 @@ public class CipherService : ICipherService } public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, string attachmentId, Guid organizationId, DateTime? lastKnownRevisionDate = null) + long requestLength, string attachmentId, Guid organizationId) { try { - ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (requestLength < 1) { throw new BadRequestException("No data to attach."); @@ -992,11 +990,6 @@ public class CipherService : ICipherService throw new BadRequestException("One or more ciphers do not belong to you."); } - if (cipher.ArchivedDate.HasValue) - { - throw new BadRequestException("Cipher cannot be shared with organization because it is archived."); - } - var attachments = cipher.GetAttachments(); var hasAttachments = attachments?.Any() ?? false; var org = await _organizationRepository.GetByIdAsync(organizationId); diff --git a/src/Events/Program.cs b/src/Events/Program.cs index 967e94ed83..1a00549005 100644 --- a/src/Events/Program.cs +++ b/src/Events/Program.cs @@ -12,26 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.EventsSettings.IdentityToken; - } - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.EventsSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Events/Startup.cs b/src/Events/Startup.cs index cfe177aa2c..75301cf08c 100644 --- a/src/Events/Startup.cs +++ b/src/Events/Startup.cs @@ -84,17 +84,16 @@ public class Startup services.AddHostedService(); } + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddRabbitMqListeners(globalSettings); } public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Events/appsettings.json b/src/Events/appsettings.json index e72b978f2f..41637c8549 100644 --- a/src/Events/appsettings.json +++ b/src/Events/appsettings.json @@ -14,9 +14,6 @@ "events": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "amazon": { "accessKeyId": "SECRET", "accessKeySecret": "SECRET", diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index c6f5afbfdd..c4c02e32d2 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -6,6 +6,7 @@ using Azure.Storage.Queues; using Bit.Core; using Bit.Core.Models.Data; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; namespace Bit.EventsProcessor; @@ -13,7 +14,7 @@ namespace Bit.EventsProcessor; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IConfiguration _configuration; + private readonly GlobalSettings _globalSettings; private Task _executingTask; private CancellationTokenSource _cts; @@ -22,10 +23,10 @@ public class AzureQueueHostedService : IHostedService, IDisposable public AzureQueueHostedService( ILogger logger, - IConfiguration configuration) + GlobalSettings globalSettings) { _logger = logger; - _configuration = configuration; + _globalSettings = globalSettings; } public Task StartAsync(CancellationToken cancellationToken) @@ -56,15 +57,18 @@ public class AzureQueueHostedService : IHostedService, IDisposable private async Task ExecuteAsync(CancellationToken cancellationToken) { - var storageConnectionString = _configuration["azureStorageConnectionString"]; - if (string.IsNullOrWhiteSpace(storageConnectionString)) + var storageConnectionString = _globalSettings.Events.ConnectionString; + var queueName = _globalSettings.Events.QueueName; + if (string.IsNullOrWhiteSpace(storageConnectionString) || + string.IsNullOrWhiteSpace(queueName)) { + _logger.LogInformation("Azure Queue Hosted Service is disabled. Missing connection string or queue name."); return; } var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); _eventWriteService = new RepositoryEventWriteService(repo); - _queueClient = new QueueClient(storageConnectionString, "event"); + _queueClient = new QueueClient(storageConnectionString, queueName); while (!cancellationToken.IsCancellationRequested) { diff --git a/src/EventsProcessor/Program.cs b/src/EventsProcessor/Program.cs index 9b7a31e6f4..e4f4ac90d1 100644 --- a/src/EventsProcessor/Program.cs +++ b/src/EventsProcessor/Program.cs @@ -11,9 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => e.Level >= globalSettings.MinLogLevel.EventsProcessorSettings.Default)); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index 67676a8afc..888dda43a1 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -1,5 +1,4 @@ using System.Globalization; -using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.IdentityModel.Logging; @@ -32,19 +31,15 @@ public class Startup // Repositories services.AddDatabaseRepositories(globalSettings); - // Hosted Services + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddAzureServiceBusListeners(globalSettings); services.AddHostedService(); } - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + public void Configure(IApplicationBuilder app) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); app.UseRouting(); diff --git a/src/Icons/Program.cs b/src/Icons/Program.cs index 237096b0b1..80c1b5728e 100644 --- a/src/Icons/Program.cs +++ b/src/Icons/Program.cs @@ -11,9 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => e.Level >= globalSettings.MinLogLevel.IconsSettings.Default)); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Icons/Startup.cs b/src/Icons/Startup.cs index 2602dd6264..5d9b5e5a30 100644 --- a/src/Icons/Startup.cs +++ b/src/Icons/Startup.cs @@ -60,11 +60,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index cc146800af..b7d4342c1b 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -109,8 +109,12 @@ public class AccountsController : Controller [HttpPost("register/send-verification-email")] public async Task PostRegisterSendVerificationEmail([FromBody] RegisterSendVerificationEmailRequestModel model) { + // Only pass fromMarketing if the feature flag is enabled + var isMarketingFeatureEnabled = _featureService.IsEnabled(FeatureFlagKeys.MarketingInitiatedPremiumFlow); + var fromMarketing = isMarketingFeatureEnabled ? model.FromMarketing : null; + var token = await _sendVerificationEmailForRegistrationCommand.Run(model.Email, model.Name, - model.ReceiveMarketingEmails); + model.ReceiveMarketingEmails, fromMarketing); if (token != null) { @@ -195,16 +199,35 @@ public class AccountsController : Controller throw new BadRequestException(ModelState); } - // Moved from API, If you modify this endpoint, please update API as well. Self hosted installs still use the API endpoints. [HttpPost("prelogin")] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) + [Obsolete("Migrating to use a more descriptive endpoint that would support different types of prelogins. " + + "Use prelogin/password instead. This endpoint has no EOL at the time of writing.")] + public async Task PostPrelogin([FromBody] PasswordPreloginRequestModel model) + { + // Same as PostPasswordPrelogin to maintain compatibility. Do not make changes in this function body, + // only make changes in MakePasswordPreloginCall + return await MakePasswordPreloginCall(model); + } + + // There are two functions done this way because the open api docs that get generated in our build pipeline + // cannot handle two of the same post attributes on the same function call. That is why there is a + // PostPrelogin and the more appropriate PostPasswordPrelogin. + [HttpPost("prelogin/password")] + public async Task PostPasswordPrelogin([FromBody] PasswordPreloginRequestModel model) + { + // Same as PostPrelogin to maintain backwards compatibility. Do not make changes in this function body, + // only make changes in MakePasswordPreloginCall + return await MakePasswordPreloginCall(model); + } + + private async Task MakePasswordPreloginCall(PasswordPreloginRequestModel model) { var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); if (kdfInformation == null) { kdfInformation = GetDefaultKdf(model.Email); } - return new PreloginResponseModel(kdfInformation); + return new PasswordPreloginResponseModel(kdfInformation, model.Email); } [HttpGet("webauthn/assertion-options")] @@ -228,19 +251,17 @@ public class AccountsController : Controller { return _defaultKdfResults[0]; } - else - { - // Compute the HMAC hash of the email - var hmacMessage = Encoding.UTF8.GetBytes(email.Trim().ToLowerInvariant()); - using var hmac = new System.Security.Cryptography.HMACSHA256(_defaultKdfHmacKey); - var hmacHash = hmac.ComputeHash(hmacMessage); - // Convert the hash to a number - var hashHex = BitConverter.ToString(hmacHash).Replace("-", string.Empty).ToLowerInvariant(); - var hashFirst8Bytes = hashHex.Substring(0, 16); - var hashNumber = long.Parse(hashFirst8Bytes, System.Globalization.NumberStyles.HexNumber); - // Find the default KDF value for this hash number - var hashIndex = (int)(Math.Abs(hashNumber) % _defaultKdfResults.Count); - return _defaultKdfResults[hashIndex]; - } + + // Compute the HMAC hash of the email + var hmacMessage = Encoding.UTF8.GetBytes(email.Trim().ToLowerInvariant()); + using var hmac = new System.Security.Cryptography.HMACSHA256(_defaultKdfHmacKey); + var hmacHash = hmac.ComputeHash(hmacMessage); + // Convert the hash to a number + var hashHex = BitConverter.ToString(hmacHash).Replace("-", string.Empty).ToLowerInvariant(); + var hashFirst8Bytes = hashHex.Substring(0, 16); + var hashNumber = long.Parse(hashFirst8Bytes, System.Globalization.NumberStyles.HexNumber); + // Find the default KDF value for this hash number + var hashIndex = (int)(Math.Abs(hashNumber) % _defaultKdfResults.Count); + return _defaultKdfResults[hashIndex]; } } diff --git a/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs new file mode 100644 index 0000000000..4787125045 --- /dev/null +++ b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs @@ -0,0 +1,30 @@ +namespace Bit.Identity.IdentityServer.RequestValidationConstants; + +public static class CustomResponseConstants +{ + public static class ResponseKeys + { + /// + /// Identifies the error model returned in the custom response when an error occurs. + /// + public static string ErrorModel => "ErrorModel"; + /// + /// This Key is used when a user is in a single organization that requires SSO authentication. The identifier + /// is used by the client to speed the redirection to the correct IdP for the user's organization. + /// + public static string SsoOrganizationIdentifier => "SsoOrganizationIdentifier"; + } +} + +public static class SsoConstants +{ + /// + /// These are messages and errors we return when SSO Validation is unsuccessful + /// + public static class RequestErrors + { + public static string SsoRequired => "sso_required"; + public static string SsoRequiredDescription => "Sso authentication is required."; + public static string SsoTwoFactorRecoveryDescription => "Two-factor recovery has been performed. SSO authentication is required."; + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index 224c7a1866..0bdf1d89c2 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -34,6 +34,7 @@ public abstract class BaseRequestValidator where T : class private readonly IEventService _eventService; private readonly IDeviceValidator _deviceValidator; private readonly ITwoFactorAuthenticationValidator _twoFactorAuthenticationValidator; + private readonly ISsoRequestValidator _ssoRequestValidator; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; @@ -43,7 +44,7 @@ public abstract class BaseRequestValidator where T : class protected ICurrentContext CurrentContext { get; } protected IPolicyService PolicyService { get; } - protected IFeatureService FeatureService { get; } + protected IFeatureService _featureService { get; } protected ISsoConfigRepository SsoConfigRepository { get; } protected IUserService _userService { get; } protected IUserDecryptionOptionsBuilder UserDecryptionOptionsBuilder { get; } @@ -56,6 +57,7 @@ public abstract class BaseRequestValidator where T : class IEventService eventService, IDeviceValidator deviceValidator, ITwoFactorAuthenticationValidator twoFactorAuthenticationValidator, + ISsoRequestValidator ssoRequestValidator, IOrganizationUserRepository organizationUserRepository, ILogger logger, ICurrentContext currentContext, @@ -76,13 +78,14 @@ public abstract class BaseRequestValidator where T : class _eventService = eventService; _deviceValidator = deviceValidator; _twoFactorAuthenticationValidator = twoFactorAuthenticationValidator; + _ssoRequestValidator = ssoRequestValidator; _organizationUserRepository = organizationUserRepository; _logger = logger; CurrentContext = currentContext; _globalSettings = globalSettings; PolicyService = policyService; _userRepository = userRepository; - FeatureService = featureService; + _featureService = featureService; SsoConfigRepository = ssoConfigRepository; UserDecryptionOptionsBuilder = userDecryptionOptionsBuilder; PolicyRequirementQuery = policyRequirementQuery; @@ -94,141 +97,16 @@ public abstract class BaseRequestValidator where T : class protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - if (FeatureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) + var validators = DetermineValidationOrder(context, request, validatorContext); + var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); + if (!allValidationSchemesSuccessful) { - var validators = DetermineValidationOrder(context, request, validatorContext); - var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); - if (!allValidationSchemesSuccessful) - { - // Each validation task is responsible for setting its own non-success status, if applicable. - return; - } - await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, - validatorContext.RememberMeRequested); + // Each validation task is responsible for setting its own non-success status, if applicable. + return; } - else - { - // 1. We need to check if the user's master password hash is correct. - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) - { - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); - return; - } - - // 2. Decide if this user belongs to an organization that requires SSO. - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) - { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; - } - - // 3. Check if 2FA is required. - (validatorContext.TwoFactorRequired, var twoFactorOrganization) = - await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); - - // This flag is used to determine if the user wants a rememberMe token sent when - // authentication is successful. - var returnRememberMeToken = false; - - if (validatorContext.TwoFactorRequired) - { - var twoFactorToken = request.Raw["TwoFactorToken"]; - var twoFactorProvider = request.Raw["TwoFactorProvider"]; - var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - // 3a. Response for 2FA required and not provided state. - if (!validTwoFactorRequest || - !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - if (resultDict == null) - { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); - return; - } - - // Include Master Password Policy in 2FA response. - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - return; - } - - var twoFactorTokenValid = - await _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); - - // 3b. Response for 2FA required but request is not valid or remember token expired state. - if (!twoFactorTokenValid) - { - // The remember me token has expired. - if (twoFactorProviderType == TwoFactorProviderType.Remember) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - - // Include Master Password Policy in 2FA response - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - } - else - { - await SendFailedTwoFactorEmail(user, twoFactorProviderType); - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - } - - return; - } - - // 3c. When the 2FA authentication is successful, we can check if the user wants a - // rememberMe token. - var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; - // Check if the user wants a rememberMe token. - if (twoFactorRemember - // if the 2FA auth was rememberMe do not send another token. - && twoFactorProviderType != TwoFactorProviderType.Remember) - { - returnRememberMeToken = true; - } - } - - // 4. Check if the user is logging in from a new device. - var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); - if (!deviceValid) - { - SetValidationErrorResult(context, validatorContext); - await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); - return; - } - - // 5. Force legacy users to the web for migration. - if (UserService.IsLegacyUser(user) && request.ClientId != "web") - { - await FailAuthForLegacyUserAsync(user, context); - return; - } - - // TODO: PM-24324 - This should be its own validator at some point. - // 6. Auth request handling - if (validatorContext.ValidatedAuthRequest != null) - { - validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; - await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); - } - - await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); - } + await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, + validatorContext.RememberMeRequested); } protected async Task FailAuthForLegacyUserAsync(User user, T context) @@ -355,36 +233,56 @@ public abstract class BaseRequestValidator where T : class private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); - if (!validatorContext.SsoRequired) + // TODO: Clean up Feature Flag: Remove this if block: PM-28281 + if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) { - return true; - } - - // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are - // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and - // review their new recovery token if desired. - // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. - // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been - // evaluated, and recovery will have been performed if requested. - // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect - // to /login. - if (validatorContext.TwoFactorRequired && - validatorContext.TwoFactorRecoveryRequested) - { - SetSsoResult(context, new Dictionary + validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); + if (!validatorContext.SsoRequired) { - { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } - }); + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (validatorContext.TwoFactorRequired && + validatorContext.TwoFactorRecoveryRequested) + { + SetSsoResult(context, + new Dictionary + { + { + "ErrorModel", + new ErrorResponseModel( + "Two-factor recovery has been performed. SSO authentication is required.") + } + }); + return false; + } + + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); return false; } - - SetSsoResult(context, - new Dictionary + else + { + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return false; + return true; + } + + SetValidationErrorResult(context, validatorContext); + return ssoValid; + } } /// @@ -651,6 +549,8 @@ public abstract class BaseRequestValidator where T : class /// user trying to login /// magic string identifying the grant type requested /// true if sso required; false if not required or already in process + [Obsolete( + "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] private async Task RequireSsoLoginAsync(User user, string grantType) { if (grantType == "authorization_code" || grantType == "client_credentials") @@ -661,7 +561,7 @@ public abstract class BaseRequestValidator where T : class } // Check if user belongs to any organization with an active SSO policy - var ssoRequired = FeatureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await PolicyRequirementQuery.GetAsync(user.Id)) .SsoRequired : await PolicyService.AnyPoliciesApplicableToUserAsync( @@ -703,11 +603,8 @@ public abstract class BaseRequestValidator where T : class private async Task SendFailedTwoFactorEmail(User user, TwoFactorProviderType failedAttemptType) { - if (FeatureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) - { - await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, - CurrentContext.IpAddress); - } + await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, + CurrentContext.IpAddress); } private async Task GetMasterPasswordPolicyAsync(User user) diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 64156ea5f3..38a4813ecd 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -36,6 +36,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -56,6 +57,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public interface ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. Sets the error state in the if SSO is required. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context); +} diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index d69d521ef7..ea2c021f63 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -31,6 +31,7 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -50,6 +51,7 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator [!IMPORTANT] +> The string constants contained herein are used in conjunction with the Auth module in the SDK. Any change to these string values _must_ be intentional and _must_ have a corresponding change in the SDK. There is snapshot testing that will fail if the strings change to help detect unintended changes to the string constants. -# Custom Claims +## Custom Claims Send access tokens contain custom claims specific to the Send the Send grant type. @@ -19,41 +17,41 @@ Send access tokens contain custom claims specific to the Send the Send grant typ 1. `send_email` - only set when the Send requires `EmailOtp` authentication type. 1. `type` - this will always be `Send` -# Authentication methods +## Authentication methods -## `NeverAuthenticate` +### `NeverAuthenticate` For a Send to be in this state two things can be true: 1. The Send has been modified and no longer allows access. 2. The Send does not exist. -## `NotAuthenticated` +### `NotAuthenticated` In this scenario the Send is not protected by any added authentication or authorization and the access token is issued to the requesting user. -## `ResourcePassword` +### `ResourcePassword` In this scenario the Send is password protected and a user must supply the correct password hash to be issued an access token. -## `EmailOtp` +### `EmailOtp` In this scenario the Send is only accessible to owners of specific email addresses. The user must submit a correct email. Once the email has been entered then ownership of the email must be established via OTP. The Otp is sent to the aforementioned email and must be supplied, along with the email, to be issued an access token. -# Send Access Request Validation +## Send Access Request Validation -## Required Parameters +### Required Parameters -### All Requests +#### All Requests - `send_id` - Base64 URL-encoded GUID of the send being accessed -### Password Protected Sends +#### Password Protected Sends - `password_hash_b64` - client hashed Base64-encoded password. -### Email OTP Protected Sends +#### Email OTP Protected Sends - `email` - Email address associated with the send - `otp` - One-time password (optional - if missing, OTP is generated and sent) -## Error Responses +### Error Responses All errors include a custom response field: ```json @@ -62,5 +60,4 @@ All errors include a custom response field: "error_description": "Human readable description", "send_access_error_type": "specific_error_code" } -``` - +``` \ No newline at end of file diff --git a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs new file mode 100644 index 0000000000..145ecc8737 --- /dev/null +++ b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs @@ -0,0 +1,124 @@ +using Bit.Core; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Services; +using Bit.Identity.IdentityServer.RequestValidationConstants; +using Duende.IdentityModel; +using Duende.IdentityServer.Validation; + +namespace Bit.Identity.IdentityServer.RequestValidators; + +/// +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public class SsoRequestValidator( + IPolicyService _policyService, + IFeatureService _featureService, + IUserSsoOrganizationIdentifierQuery _userSsoOrganizationIdentifierQuery, + IPolicyRequirementQuery _policyRequirementQuery) : ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. + /// Sets context.SsoRequired to indicate whether SSO is required. + /// If SSO is required, sets the validation error result and custom response in the context. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + public async Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context) + { + context.SsoRequired = await RequireSsoAuthenticationAsync(user, request.GrantType); + + if (!context.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (context.TwoFactorRequired && context.TwoFactorRecoveryRequested) + { + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription); + return false; + } + + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoRequiredDescription); + return false; + } + + /// + /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are + /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. + /// If the GrantType is authorization_code or client_credentials we know the user is trying to log in + /// using the SSO flow so they are allowed to continue. + /// + /// user trying to log in + /// magic string identifying the grant type requested + /// true if sso required; false if not required or already in process + private async Task RequireSsoAuthenticationAsync(User user, string grantType) + { + if (grantType == OidcConstants.GrantTypes.AuthorizationCode || + grantType == OidcConstants.GrantTypes.ClientCredentials) + { + // SSO is not required for users already using SSO to authenticate which uses the authorization_code grant type, + // or logging-in via API key which is the client_credentials grant type. + // Allow user to continue request validation + return false; + } + + // Check if user belongs to any organization with an active SSO policy + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + ? (await _policyRequirementQuery.GetAsync(user.Id)) + .SsoRequired + : await _policyService.AnyPoliciesApplicableToUserAsync( + user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + + if (ssoRequired) + { + return true; + } + + // Default - SSO is not required + return false; + } + + /// + /// Sets the customResponse in the context with the error result for the SSO validation failure. + /// + /// The validator context to update with error details. + /// The error message to return to the client. + private async Task SetContextCustomResponseSsoErrorAsync(CustomValidatorRequestContext context, string errorMessage) + { + var ssoOrganizationIdentifier = await _userSsoOrganizationIdentifierQuery.GetSsoOrganizationIdentifierAsync(context.User.Id); + + context.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = OidcConstants.TokenErrors.InvalidGrant, + ErrorDescription = errorMessage + }; + + context.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(errorMessage) } + }; + + // Include organization identifier in the response if available + if (!string.IsNullOrEmpty(ssoOrganizationIdentifier)) + { + context.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier] = ssoOrganizationIdentifier; + } + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index 294df1c18d..e4cd60827e 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -38,6 +38,7 @@ public class WebAuthnGrantValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -59,6 +60,7 @@ public class WebAuthnGrantValidator : BaseRequestValidator { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.IpRateLimit; - } - - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.IdentityToken; - } - - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.Default; - })); - }); + }) + .AddSerilogFileLogging(); } } diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index 74344977a0..5dc443a73c 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -170,14 +170,11 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings, ILogger logger) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index e9056d030e..7e64975c95 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -26,6 +26,7 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient, SendPasswordRequestValidator>(); services.AddTransient, SendEmailOtpRequestValidator>(); diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json index 16c3efe46b..c21d2dff3b 100644 --- a/src/Identity/appsettings.json +++ b/src/Identity/appsettings.json @@ -27,9 +27,6 @@ "events": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs index 005e93c6aa..af24e11a0e 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -20,10 +20,9 @@ public class OrganizationIntegrationConfigurationRepository : Repository> GetConfigurationDetailsAsync( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) + public async Task> + GetManyByEventTypeOrganizationIdIntegrationType(EventType eventType, Guid organizationId, + IntegrationType integrationType) { using (var connection = new SqlConnection(ConnectionString)) { diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index ed5708844d..bd670347a9 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.Utilities.DebuggingInstruments; using Bit.Core.Entities; @@ -624,7 +625,11 @@ public class OrganizationUserRepository : Repository, IO await connection.ExecuteAsync( "[dbo].[OrganizationUser_SetStatusForUsersByGuidIdArray]", - new { OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP(), Status = OrganizationUserStatusType.Revoked }, + new + { + OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP(), + Status = OrganizationUserStatusType.Revoked + }, commandType: CommandType.StoredProcedure); } @@ -671,7 +676,7 @@ public class OrganizationUserRepository : Repository, IO commandType: CommandType.StoredProcedure); } - public async Task ConfirmOrganizationUserAsync(OrganizationUser organizationUser) + public async Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) { await using var connection = new SqlConnection(_marsConnectionString); @@ -679,12 +684,29 @@ public class OrganizationUserRepository : Repository, IO $"[{Schema}].[OrganizationUser_ConfirmById]", new { - organizationUser.Id, - organizationUser.UserId, + Id = organizationUserToConfirm.OrganizationUserId, + UserId = organizationUserToConfirm.UserId, RevisionDate = DateTime.UtcNow.Date, - Key = organizationUser.Key + Key = organizationUserToConfirm.Key }); return rowCount > 0; } + + public async Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationUserUserDetails_ReadByOrganizationIdUserId]", + new + { + OrganizationId = organizationId, + UserId = userId + }, + commandType: CommandType.StoredProcedure); + + return result; + } + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs index 467857612f..c05ff040e5 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs @@ -61,6 +61,18 @@ public class ProviderUserRepository : Repository, IProviderU } } + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + await using var connection = new SqlConnection(ConnectionString); + + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadManyByManyUserIds]", + new { UserIds = userIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) { using (var connection = new SqlConnection(ConnectionString)) diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs index 91cbc40ff6..a8171c286b 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs @@ -148,4 +148,16 @@ public class OrganizationDomainRepository : Repository commandType: CommandType.StoredProcedure) > 0; } } + + public async Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null) + { + await using var connection = new SqlConnection(ConnectionString); + + var result = await connection.QueryFirstOrDefaultAsync( + $"[{Schema}].[OrganizationDomain_HasVerifiedDomainWithBlockPolicy]", + new { DomainName = domainName, ExcludeOrganizationId = excludeOrganizationId }, + commandType: CommandType.StoredProcedure); + + return result; + } } diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 6b11d64cda..224351f034 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -1,17 +1,18 @@ using System.Data; using System.Text.Json; using Bit.Core; +using Bit.Core.Billing.Premium.Models; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Core.Utilities; using Dapper; using Microsoft.AspNetCore.DataProtection; using Microsoft.Data.SqlClient; -#nullable enable - namespace Bit.Infrastructure.Dapper.Repositories; public class UserRepository : Repository, IUserRepository @@ -288,6 +289,63 @@ public class UserRepository : Repository, IUserRepository UnprotectData(user); } + public async Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null) + { + if (!accountKeysData.IsV2Encryption()) + { + throw new ArgumentException("Provided account keys data is not valid V2 encryption data.", nameof(accountKeysData)); + } + + var timestamp = DateTime.UtcNow; + var signatureKeyPairId = CoreHelpers.GenerateComb(); + + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + + await using var transaction = connection.BeginTransaction(); + try + { + await connection.ExecuteAsync( + "[dbo].[User_UpdateAccountCryptographicState]", + new + { + Id = userId, + PublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.PublicKey, + PrivateKey = accountKeysData.PublicKeyEncryptionKeyPairData.WrappedPrivateKey, + SignedPublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.SignedPublicKey, + SecurityState = accountKeysData.SecurityStateData!.SecurityState, + SecurityVersion = accountKeysData.SecurityStateData!.SecurityVersion, + SignatureKeyPairId = signatureKeyPairId, + SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm, + SigningKey = accountKeysData.SignatureKeyPairData!.WrappedSigningKey, + VerifyingKey = accountKeysData.SignatureKeyPairData!.VerifyingKey, + RevisionDate = timestamp, + AccountRevisionDate = timestamp + }, + transaction: transaction, + commandType: CommandType.StoredProcedure); + + // Update user data that depends on cryptographic state + if (updateUserDataActions != null) + { + foreach (var action in updateUserDataActions) + { + await action(connection, transaction); + } + } + + await transaction.CommitAsync(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } + } + public async Task> GetManyAsync(IEnumerable ids) { using (var connection = new SqlConnection(ReadOnlyConnectionString)) @@ -324,6 +382,25 @@ public class UserRepository : Repository, IUserRepository return result.SingleOrDefault(); } + public async Task> GetPremiumAccessByIdsAsync(IEnumerable ids) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadPremiumAccessByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetPremiumAccessAsync(Guid userId) + { + var result = await GetPremiumAccessByIdsAsync([userId]); + return result.SingleOrDefault(); + } + private async Task ProtectDataAndSaveAsync(User user, Func saveTask) { if (user == null) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs index 47369f5e3d..93d8fe2d7d 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs @@ -18,7 +18,7 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration new { o.Id, o.Enabled }), - o => o.UseTotp); + o => new { o.UseTotp, o.UsersGetPremium }); builder.ToTable(nameof(Organization)); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs index fc391b958c..ff8f92fd91 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -17,16 +17,17 @@ public class OrganizationIntegrationConfigurationRepository : Repository context.OrganizationIntegrationConfigurations) { } - public async Task> GetConfigurationDetailsAsync( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) + public async Task> + GetManyByEventTypeOrganizationIdIntegrationType(EventType eventType, Guid organizationId, + IntegrationType integrationType) { using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); var query = new OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery( - organizationId, eventType, integrationType + organizationId, + eventType, + integrationType ); return await query.Run(dbContext).ToListAsync(); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index ebc2bc6606..f2da58a1dd 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -113,7 +113,8 @@ public class OrganizationRepository : Repository ConfirmOrganizationUserAsync(Core.Entities.OrganizationUser organizationUser) + public async Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) { using var scope = ServiceScopeFactory.CreateScope(); await using var dbContext = GetDatabaseContext(scope); var result = await dbContext.OrganizationUsers - .Where(ou => ou.Id == organizationUser.Id && ou.Status == OrganizationUserStatusType.Accepted) + .Where(ou => ou.Id == organizationUserToConfirm.OrganizationUserId + && ou.Status == OrganizationUserStatusType.Accepted) .ExecuteUpdateAsync(x => x .SetProperty(y => y.Status, OrganizationUserStatusType.Confirmed) - .SetProperty(y => y.Key, organizationUser.Key)); + .SetProperty(y => y.Key, organizationUserToConfirm.Key)); if (result <= 0) { return false; } - await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUser.Id); + await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserToConfirm.OrganizationUserId); return true; } + +#nullable enable + + public async Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var entity = await view.Run(dbContext).SingleOrDefaultAsync(ou => ou.OrganizationId == organizationId && ou.UserId == userId); + return entity; + } + } +#nullable disable + + } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs index 1cca7a9bbb..894fb255be 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs @@ -217,7 +217,7 @@ public class PolicyRepository : Repository new OrganizationPolicyDetails { diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs index 5474e3e217..8f9a38f9b6 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs @@ -96,6 +96,20 @@ public class ProviderUserRepository : return await query.ToArrayAsync(); } } + + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + + var dbContext = GetDatabaseContext(scope); + + var query = from pu in dbContext.ProviderUsers + where pu.UserId != null && userIds.Contains(pu.UserId.Value) + select pu; + + return await query.ToArrayAsync(); + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs index b4441c5084..421bb9407a 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs @@ -1,31 +1,21 @@ -#nullable enable - -using Bit.Core.Enums; +using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations; namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; -public class OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery : IQuery +public class OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery( + Guid organizationId, + EventType eventType, + IntegrationType integrationType) + : IQuery { - private readonly Guid _organizationId; - private readonly EventType _eventType; - private readonly IntegrationType _integrationType; - - public OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery(Guid organizationId, EventType eventType, IntegrationType integrationType) - { - _organizationId = organizationId; - _eventType = eventType; - _integrationType = integrationType; - } - public IQueryable Run(DatabaseContext dbContext) { var query = from oic in dbContext.OrganizationIntegrationConfigurations - join oi in dbContext.OrganizationIntegrations on oic.OrganizationIntegrationId equals oi.Id into oioic - from oi in dbContext.OrganizationIntegrations - where oi.OrganizationId == _organizationId && - oi.Type == _integrationType && - oic.EventType == _eventType + join oi in dbContext.OrganizationIntegrations on oic.OrganizationIntegrationId equals oi.Id + where oi.OrganizationId == organizationId && + oi.Type == integrationType && + (oic.EventType == eventType || oic.EventType == null) select new OrganizationIntegrationConfigurationDetails() { Id = oic.Id, diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs index 504a75c9f2..f433e9096b 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs @@ -74,7 +74,8 @@ public class OrganizationUserOrganizationDetailsViewQuery : IQuery>(verifiedDomains); } + public async Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var query = from od in dbContext.OrganizationDomains + join o in dbContext.Organizations on od.OrganizationId equals o.Id + join p in dbContext.Policies on o.Id equals p.OrganizationId + where od.DomainName == domainName + && od.VerifiedDate != null + && o.Enabled + && o.UsePolicies + && o.UseOrganizationDomains + && (!excludeOrganizationId.HasValue || o.Id != excludeOrganizationId.Value) + && p.Type == Core.AdminConsole.Enums.PolicyType.BlockClaimedDomainAccountCreation + && p.Enabled + select od; + + return await query.AnyAsync(); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index 809704edb7..9bf093e506 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -1,4 +1,6 @@ using AutoMapper; +using Bit.Core.Billing.Premium.Models; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; using Bit.Core.Repositories; @@ -241,6 +243,80 @@ public class UserRepository : Repository, IUserR await transaction.CommitAsync(); } + public async Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null) + { + if (!accountKeysData.IsV2Encryption()) + { + throw new ArgumentException("Provided account keys data is not valid V2 encryption data.", nameof(accountKeysData)); + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + await using var transaction = await dbContext.Database.BeginTransactionAsync(); + + // Update user + var userEntity = await dbContext.Users.FindAsync(userId); + if (userEntity == null) + { + throw new ArgumentException("User not found", nameof(userId)); + } + + // Update public key encryption key pair + var timestamp = DateTime.UtcNow; + + userEntity.RevisionDate = timestamp; + userEntity.AccountRevisionDate = timestamp; + + // V1 + V2 user crypto changes + userEntity.PublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.PublicKey; + userEntity.PrivateKey = accountKeysData.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + + userEntity.SecurityState = accountKeysData.SecurityStateData!.SecurityState; + userEntity.SecurityVersion = accountKeysData.SecurityStateData.SecurityVersion; + userEntity.SignedPublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.SignedPublicKey; + + // Replace existing keypair if it exists + var existingKeyPair = await dbContext.UserSignatureKeyPairs + .FirstOrDefaultAsync(x => x.UserId == userId); + if (existingKeyPair != null) + { + existingKeyPair.SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm; + existingKeyPair.SigningKey = accountKeysData.SignatureKeyPairData.WrappedSigningKey; + existingKeyPair.VerifyingKey = accountKeysData.SignatureKeyPairData.VerifyingKey; + existingKeyPair.RevisionDate = timestamp; + } + else + { + var newKeyPair = new UserSignatureKeyPair + { + UserId = userId, + SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm, + SigningKey = accountKeysData.SignatureKeyPairData.WrappedSigningKey, + VerifyingKey = accountKeysData.SignatureKeyPairData.VerifyingKey, + CreationDate = timestamp, + RevisionDate = timestamp + }; + newKeyPair.SetNewId(); + await dbContext.UserSignatureKeyPairs.AddAsync(newKeyPair); + } + + await dbContext.SaveChangesAsync(); + + // Update additional user data within the same transaction + if (updateUserDataActions != null) + { + foreach (var action in updateUserDataActions) + { + await action(); + } + } + await transaction.CommitAsync(); + } + public async Task> GetManyAsync(IEnumerable ids) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -275,6 +351,36 @@ public class UserRepository : Repository, IUserR return result.FirstOrDefault(); } + public async Task> GetPremiumAccessByIdsAsync(IEnumerable ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var users = await dbContext.Users + .Where(x => ids.Contains(x.Id)) + .Include(u => u.OrganizationUsers) + .ThenInclude(ou => ou.Organization) + .ToListAsync(); + + return users.Select(user => new UserPremiumAccess + { + Id = user.Id, + PersonalPremium = user.Premium, + OrganizationPremium = user.OrganizationUsers + .Any(ou => ou.Organization != null && + ou.Organization.Enabled == true && + ou.Organization.UsersGetPremium == true) + }).ToList(); + } + } + + public async Task GetPremiumAccessAsync(Guid userId) + { + var result = await GetPremiumAccessByIdsAsync([userId]); + return result.FirstOrDefault(); + } + public override async Task DeleteAsync(Core.Entities.User user) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index 3c45afe530..ebe39852f4 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -704,6 +704,9 @@ public class CipherRepository : Repository>(notificationJson, _deserializerOptions); + if (policyData is null) + { + return; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(policyData.Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, policyData, cancellationToken); + + } } diff --git a/src/Notifications/Program.cs b/src/Notifications/Program.cs index 072c2404c4..2792391729 100644 --- a/src/Notifications/Program.cs +++ b/src/Notifications/Program.cs @@ -1,5 +1,4 @@ using Bit.Core.Utilities; -using Serilog.Events; namespace Bit.Notifications; @@ -13,37 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.NotificationsSettings.IdentityToken; - } - - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text == "Failed connection handshake.") - { - return false; - } - - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text.StartsWith("Failed writing message.")) - { - return false; - } - - if (e.Level == LogEventLevel.Warning && - e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.NotificationsSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index 2889e90d3b..65904ea698 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -82,11 +82,9 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); diff --git a/src/Notifications/appsettings.json b/src/Notifications/appsettings.json index 020d98cbd6..e36ec02dad 100644 --- a/src/Notifications/appsettings.json +++ b/src/Notifications/appsettings.json @@ -18,9 +18,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "amazon": { "accessKeyId": "SECRET", "accessKeySecret": "SECRET", diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 78b8a61015..91047d98bc 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -6,12 +6,9 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; -using Azure.Messaging.ServiceBus; using Bit.Core; using Bit.Core.AdminConsole.AbilitiesCache; using Bit.Core.AdminConsole.Models.Business.Tokenables; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.AdminConsole.Models.Teams; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; @@ -20,7 +17,6 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Auth.IdentityServer; -using Bit.Core.Auth.LoginFeatures; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; @@ -74,8 +70,6 @@ using Microsoft.AspNetCore.HttpOverrides; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc.Localization; using Microsoft.Azure.Cosmos.Fluent; -using Microsoft.Bot.Builder; -using Microsoft.Bot.Builder.Integration.AspNet.Core; using Microsoft.Extensions.Caching.Cosmos; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Configuration; @@ -84,7 +78,9 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.OpenApi.Models; using StackExchange.Redis; +using Swashbuckle.AspNetCore.SwaggerGen; using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; using TableStorageRepos = Bit.Core.Repositories.TableStorage; @@ -136,7 +132,6 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddLoginServices(); services.AddScoped(); services.AddVaultServices(); services.AddReportingServices(); @@ -243,7 +238,7 @@ public static class ServiceCollectionExtensions PrivateKey = globalSettings.Braintree.PrivateKey }; }); - services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); // Legacy mailer service @@ -340,6 +335,7 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); } @@ -522,116 +518,6 @@ public static class ServiceCollectionExtensions return globalSettings; } - public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) - { - if (IsAzureServiceBusEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddSingleton(); - return services; - } - - if (IsRabbitMqEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddSingleton(); - return services; - } - - if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.TryAddSingleton(); - return services; - } - - if (globalSettings.SelfHosted) - { - services.TryAddSingleton(); - return services; - } - - services.TryAddSingleton(); - return services; - } - - public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsAzureServiceBusEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsRabbitMqEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && - CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && - CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) - { - services.AddHttpClient(SlackService.HttpClientName); - services.TryAddSingleton(); - } - else - { - services.TryAddSingleton(); - } - - return services; - } - - public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && - CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && - CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) - { - services.AddHttpClient(TeamsService.HttpClientName); - services.TryAddSingleton(); - services.TryAddSingleton(sp => sp.GetRequiredService()); - services.TryAddSingleton(sp => sp.GetRequiredService()); - services.TryAddSingleton(sp => - new BotFrameworkHttpAdapter( - new TeamsBotCredentialProvider( - clientId: globalSettings.Teams.ClientId, - clientSecret: globalSettings.Teams.ClientSecret - ) - ) - ); - } - else - { - services.TryAddSingleton(); - } - - return services; - } - public static void UseDefaultMiddleware(this IApplicationBuilder app, IWebHostEnvironment env, GlobalSettings globalSettings) { @@ -645,7 +531,7 @@ public static class ServiceCollectionExtensions ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto }; - if (!globalSettings.UnifiedDeployment) + if (!globalSettings.LiteDeployment) { // Trust the X-Forwarded-Host header of the nginx docker container try @@ -878,185 +764,60 @@ public static class ServiceCollectionExtensions return (provider, connectionString); } - private static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration + /// + /// Adds a server with its corresponding OAuth2 client credentials security definition and requirement. + /// + /// The SwaggerGen configuration + /// Unique identifier for this server (e.g., "us-server", "eu-server") + /// The API server URL + /// The identity server token URL + /// Human-readable description for the server + public static void AddSwaggerServerWithSecurity( + this SwaggerGenOptions config, + string serverId, + string serverUrl, + string identityTokenUrl, + string serverDescription) { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - configurationCache: provider.GetRequiredService(), - userRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - logger: provider.GetRequiredService>>() - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.EventPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusIntegrationListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredService>(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, - GlobalSettings globalSettings) - { - // Add common services - services.TryAddSingleton(); - services.TryAddSingleton(provider => - provider.GetRequiredService()); - services.AddHostedService(provider => provider.GetRequiredService()); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - - // Add services in support of handlers - services.AddSlackService(globalSettings); - services.AddTeamsService(globalSettings); - services.TryAddSingleton(TimeProvider.System); - services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); - services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); - - // Add integration handlers - services.TryAddSingleton, SlackIntegrationHandler>(); - services.TryAddSingleton, WebhookIntegrationHandler>(); - services.TryAddSingleton, DatadogIntegrationHandler>(); - services.TryAddSingleton, TeamsIntegrationHandler>(); - - var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); - var slackConfiguration = new SlackListenerConfiguration(globalSettings); - var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); - var hecConfiguration = new HecListenerConfiguration(globalSettings); - var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); - var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); - - if (IsRabbitMqEnabled(globalSettings)) + // Add server + config.AddServer(new OpenApiServer { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredService(), - configuration: repositoryConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddRabbitMqIntegration(slackConfiguration); - services.AddRabbitMqIntegration(webhookConfiguration); - services.AddRabbitMqIntegration(hecConfiguration); - services.AddRabbitMqIntegration(datadogConfiguration); - services.AddRabbitMqIntegration(teamsConfiguration); - } + Url = serverUrl, + Description = serverDescription + }); - if (IsAzureServiceBusEnabled(globalSettings)) + // Add security definition + config.AddSecurityDefinition(serverId, new OpenApiSecurityScheme { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: repositoryConfiguration, - handler: provider.GetRequiredService(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = repositoryConfiguration.EventPrefetchCount, - MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddAzureServiceBusIntegration(slackConfiguration); - services.AddAzureServiceBusIntegration(webhookConfiguration); - services.AddAzureServiceBusIntegration(hecConfiguration); - services.AddAzureServiceBusIntegration(datadogConfiguration); - services.AddAzureServiceBusIntegration(teamsConfiguration); - } + Type = SecuritySchemeType.OAuth2, + Description = $"**Use this option if you've selected the {serverDescription}**", + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri(identityTokenUrl), + Scopes = new Dictionary + { + { ApiScopes.ApiOrganization, $"Organization APIs ({serverDescription})" }, + }, + } + }, + }); - return services; - } - - private static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration - { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - configurationCache: provider.GetRequiredService(), - userRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - logger: provider.GetRequiredService>>() - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqIntegrationListenerService( - handler: provider.GetRequiredService>(), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService(), - timeProvider: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static bool IsAzureServiceBusEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName); - } - - private static bool IsRabbitMqEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName); + // Add security requirement + config.AddSecurityRequirement(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = serverId + }, + }, + [ApiScopes.ApiOrganization] + } + }); } } diff --git a/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql b/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql new file mode 100644 index 0000000000..bfa9d932c5 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql @@ -0,0 +1,34 @@ +CREATE PROCEDURE [dbo].[OrganizationDomain_HasVerifiedDomainWithBlockPolicy] + @DomainName NVARCHAR(255), + @ExcludeOrganizationId UNIQUEIDENTIFIER = NULL +AS +BEGIN + SET NOCOUNT ON + + -- Check if any organization has a verified domain matching the domain name + -- with the BlockClaimedDomainAccountCreation policy enabled (Type = 19) + -- If @ExcludeOrganizationId is provided, exclude that organization from the check + IF EXISTS ( + SELECT 1 + FROM [dbo].[OrganizationDomain] OD + INNER JOIN [dbo].[Organization] O + ON OD.OrganizationId = O.Id + INNER JOIN [dbo].[Policy] P + ON O.Id = P.OrganizationId + WHERE OD.DomainName = @DomainName + AND OD.VerifiedDate IS NOT NULL + AND O.Enabled = 1 + AND O.UsePolicies = 1 + AND O.UseOrganizationDomains = 1 + AND (@ExcludeOrganizationId IS NULL OR O.Id != @ExcludeOrganizationId) + AND P.Type = 19 -- BlockClaimedDomainAccountCreation + AND P.Enabled = 1 + ) + BEGIN + SELECT CAST(1 AS BIT) AS HasBlockPolicy + END + ELSE + BEGIN + SELECT CAST(0 AS BIT) AS HasBlockPolicy + END +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql b/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql index 3240402916..7124be73fb 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql @@ -11,7 +11,7 @@ BEGIN FROM [dbo].[OrganizationIntegrationConfigurationDetailsView] oic WHERE - oic.[EventType] = @EventType + (oic.[EventType] = @EventType OR oic.[EventType] IS NULL) AND oic.[OrganizationId] = @OrganizationId AND diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql new file mode 100644 index 0000000000..6113664b76 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE [dbo].[OrganizationUserUserDetails_ReadByOrganizationIdUserId] + @OrganizationId UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + +SELECT + * +FROM + [dbo].[OrganizationUserUserDetailsView] +WHERE + [OrganizationId] = @OrganizationId +AND + [UserId] = @UserId +END +GO diff --git a/src/Sql/dbo/Stored Procedures/Organization_Create.sql b/src/Sql/dbo/Stored Procedures/Organization_Create.sql index e37fa0e940..decd406280 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Create.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Create.sql @@ -59,7 +59,8 @@ CREATE PROCEDURE [dbo].[Organization_Create] @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, @SyncSeats BIT = 0, - @UseAutomaticUserConfirmation BIT = 0 + @UseAutomaticUserConfirmation BIT = 0, + @UsePhishingBlocker BIT = 0 AS BEGIN SET NOCOUNT ON @@ -126,7 +127,8 @@ BEGIN [UseOrganizationDomains], [UseAdminSponsoredFamilies], [SyncSeats], - [UseAutomaticUserConfirmation] + [UseAutomaticUserConfirmation], + [UsePhishingBlocker] ) VALUES ( @@ -190,6 +192,7 @@ BEGIN @UseOrganizationDomains, @UseAdminSponsoredFamilies, @SyncSeats, - @UseAutomaticUserConfirmation + @UseAutomaticUserConfirmation, + @UsePhishingBlocker ); END diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql index 59226e59db..9efefe8d54 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql @@ -28,7 +28,8 @@ BEGIN [LimitItemDeletion], [UseOrganizationDomains], [UseAdminSponsoredFamilies], - [UseAutomaticUserConfirmation] + [UseAutomaticUserConfirmation], + [UsePhishingBlocker] FROM [dbo].[Organization] END diff --git a/src/Sql/dbo/Stored Procedures/Organization_Update.sql b/src/Sql/dbo/Stored Procedures/Organization_Update.sql index 4807c7bb50..9fd1b59460 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Update.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Update.sql @@ -59,7 +59,8 @@ CREATE PROCEDURE [dbo].[Organization_Update] @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, @SyncSeats BIT = 0, - @UseAutomaticUserConfirmation BIT = 0 + @UseAutomaticUserConfirmation BIT = 0, + @UsePhishingBlocker BIT = 0 AS BEGIN SET NOCOUNT ON @@ -126,7 +127,8 @@ BEGIN [UseOrganizationDomains] = @UseOrganizationDomains, [UseAdminSponsoredFamilies] = @UseAdminSponsoredFamilies, [SyncSeats] = @SyncSeats, - [UseAutomaticUserConfirmation] = @UseAutomaticUserConfirmation + [UseAutomaticUserConfirmation] = @UseAutomaticUserConfirmation, + [UsePhishingBlocker] = @UsePhishingBlocker WHERE [Id] = @Id; END diff --git a/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql new file mode 100644 index 0000000000..4fe8d153e4 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[ProviderUser_ReadManyByManyUserIds] + @UserIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + SELECT + [pu].* + FROM + [dbo].[ProviderUserView] AS [pu] + INNER JOIN + @UserIds [u] ON [u].[Id] = [pu].[UserId] +END diff --git a/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql b/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql new file mode 100644 index 0000000000..a4c73c39df --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[User_ReadPremiumAccessByIds] + @Ids [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + SELECT + UPA.[Id], + UPA.[PersonalPremium], + UPA.[OrganizationPremium] + FROM + [dbo].[UserPremiumAccessView] UPA + WHERE + UPA.[Id] IN (SELECT [Id] FROM @Ids) +END diff --git a/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql b/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql new file mode 100644 index 0000000000..8f1fb664ea --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql @@ -0,0 +1,65 @@ +CREATE PROCEDURE [dbo].[User_UpdateAccountCryptographicState] + @Id UNIQUEIDENTIFIER, + @PublicKey NVARCHAR(MAX), + @PrivateKey NVARCHAR(MAX), + @SignedPublicKey NVARCHAR(MAX) = NULL, + @SecurityState NVARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignatureKeyPairId UNIQUEIDENTIFIER = NULL, + @SignatureAlgorithm TINYINT = NULL, + @SigningKey VARCHAR(MAX) = NULL, + @VerifyingKey VARCHAR(MAX) = NULL, + @RevisionDate DATETIME2(7), + @AccountRevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[User] + SET + [PublicKey] = @PublicKey, + [PrivateKey] = @PrivateKey, + [SignedPublicKey] = @SignedPublicKey, + [SecurityState] = @SecurityState, + [SecurityVersion] = @SecurityVersion, + [RevisionDate] = @RevisionDate, + [AccountRevisionDate] = @AccountRevisionDate + WHERE + [Id] = @Id + + IF EXISTS (SELECT 1 FROM [dbo].[UserSignatureKeyPair] WHERE [UserId] = @Id) + BEGIN + UPDATE [dbo].[UserSignatureKeyPair] + SET + [SignatureAlgorithm] = @SignatureAlgorithm, + [SigningKey] = @SigningKey, + [VerifyingKey] = @VerifyingKey, + [RevisionDate] = @RevisionDate + WHERE + [UserId] = @Id + END + ELSE + BEGIN + INSERT INTO [dbo].[UserSignatureKeyPair] + ( + [Id], + [UserId], + [SignatureAlgorithm], + [SigningKey], + [VerifyingKey], + [CreationDate], + [RevisionDate] + ) + VALUES + ( + @SignatureKeyPairId, + @Id, + @SignatureAlgorithm, + @SigningKey, + @VerifyingKey, + @RevisionDate, + @RevisionDate + ) + END +END diff --git a/src/Sql/dbo/Tables/Organization.sql b/src/Sql/dbo/Tables/Organization.sql index e1ad6863af..f07cd4ce0d 100644 --- a/src/Sql/dbo/Tables/Organization.sql +++ b/src/Sql/dbo/Tables/Organization.sql @@ -60,6 +60,8 @@ CREATE TABLE [dbo].[Organization] ( [UseAdminSponsoredFamilies] BIT NOT NULL CONSTRAINT [DF_Organization_UseAdminSponsoredFamilies] DEFAULT (0), [SyncSeats] BIT NOT NULL CONSTRAINT [DF_Organization_SyncSeats] DEFAULT (0), [UseAutomaticUserConfirmation] BIT NOT NULL CONSTRAINT [DF_Organization_UseAutomaticUserConfirmation] DEFAULT (0), + [MaxStorageGbIncreased] SMALLINT NULL, + [UsePhishingBlocker] BIT NOT NULL CONSTRAINT [DF_Organization_UsePhishingBlocker] DEFAULT (0), CONSTRAINT [PK_Organization] PRIMARY KEY CLUSTERED ([Id] ASC) ); @@ -67,7 +69,7 @@ CREATE TABLE [dbo].[Organization] ( GO CREATE NONCLUSTERED INDEX [IX_Organization_Enabled] ON [dbo].[Organization]([Id] ASC, [Enabled] ASC) - INCLUDE ([UseTotp]); + INCLUDE ([UseTotp], [UsersGetPremium]); GO CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier] diff --git a/src/Sql/dbo/Tables/User.sql b/src/Sql/dbo/Tables/User.sql index dc772ff1a7..854fe34f4a 100644 --- a/src/Sql/dbo/Tables/User.sql +++ b/src/Sql/dbo/Tables/User.sql @@ -45,6 +45,7 @@ [SecurityState] VARCHAR (MAX) NULL, [SecurityVersion] INT NULL, [SignedPublicKey] VARCHAR (MAX) NULL, + [MaxStorageGbIncreased] SMALLINT NULL, CONSTRAINT [PK_User] PRIMARY KEY CLUSTERED ([Id] ASC) ); diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql index ee7e00b32a..6082e89efc 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql @@ -30,4 +30,10 @@ BEGIN DECLARE @UpdateCollectionsSuccess INT EXEC @UpdateCollectionsSuccess = [dbo].[Cipher_UpdateCollections] @Id, @UserId, @OrganizationId, @CollectionIds + + -- Bump the account revision date AFTER collections are assigned. + IF @UpdateCollectionsSuccess = 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDateByCipherId] @Id, @OrganizationId + END END diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql index 55852c4d27..3fe877c168 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql @@ -38,8 +38,13 @@ BEGIN [Data] = @Data, [Attachments] = @Attachments, [RevisionDate] = @RevisionDate, - [DeletedDate] = @DeletedDate, [Key] = @Key, [ArchivedDate] = @ArchivedDate - -- No need to update CreationDate, Favorites, Folders, or Type since that data will not change + [DeletedDate] = @DeletedDate, + [Key] = @Key, + [ArchivedDate] = @ArchivedDate, + [Folders] = @Folders, + [Favorites] = @Favorites, + [Reprompt] = @Reprompt + -- No need to update CreationDate or Type since that data will not change WHERE [Id] = @Id diff --git a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql index a7e1db6e81..ffd6810b1b 100644 --- a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql @@ -24,7 +24,7 @@ SELECT O.[UseSecretsManager], O.[Seats], O.[MaxCollections], - O.[MaxStorageGb], + COALESCE(O.[MaxStorageGbIncreased], O.[MaxStorageGb]) AS [MaxStorageGb], O.[Identifier], OU.[Key], OU.[ResetPasswordKey], @@ -55,7 +55,8 @@ SELECT O.[UseAdminSponsoredFamilies], O.[UseOrganizationDomains], OS.[IsAdminInitiated], - O.[UseAutomaticUserConfirmation] + O.[UseAutomaticUserConfirmation], + O.[UsePhishingBlocker] FROM [dbo].[OrganizationUser] OU LEFT JOIN diff --git a/src/Sql/dbo/Views/OrganizationView.sql b/src/Sql/dbo/Views/OrganizationView.sql index 58989273fd..6e42d08338 100644 --- a/src/Sql/dbo/Views/OrganizationView.sql +++ b/src/Sql/dbo/Views/OrganizationView.sql @@ -1,6 +1,67 @@ CREATE VIEW [dbo].[OrganizationView] AS SELECT - * + [Id], + [Identifier], + [Name], + [BusinessName], + [BusinessAddress1], + [BusinessAddress2], + [BusinessAddress3], + [BusinessCountry], + [BusinessTaxNumber], + [BillingEmail], + [Plan], + [PlanType], + [Seats], + [MaxCollections], + [UsePolicies], + [UseSso], + [UseGroups], + [UseDirectory], + [UseEvents], + [UseTotp], + [Use2fa], + [UseApi], + [UseResetPassword], + [SelfHost], + [UsersGetPremium], + [Storage], + COALESCE([MaxStorageGbIncreased], [MaxStorageGb]) AS [MaxStorageGb], + [Gateway], + [GatewayCustomerId], + [GatewaySubscriptionId], + [ReferenceData], + [Enabled], + [LicenseKey], + [PublicKey], + [PrivateKey], + [TwoFactorProviders], + [ExpirationDate], + [CreationDate], + [RevisionDate], + [OwnersNotifiedOfAutoscaling], + [MaxAutoscaleSeats], + [UseKeyConnector], + [UseScim], + [UseCustomPermissions], + [UseSecretsManager], + [Status], + [UsePasswordManager], + [SmSeats], + [SmServiceAccounts], + [MaxAutoscaleSmSeats], + [MaxAutoscaleSmServiceAccounts], + [SecretsManagerBeta], + [LimitCollectionCreation], + [LimitCollectionDeletion], + [LimitItemDeletion], + [AllowAdminAccessToAllCollectionItems], + [UseRiskInsights], + [UseOrganizationDomains], + [UseAdminSponsoredFamilies], + [SyncSeats], + [UseAutomaticUserConfirmation], + [UsePhishingBlocker] FROM [dbo].[Organization] diff --git a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql index 42e877ab15..e1d5ef9144 100644 --- a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql @@ -23,7 +23,7 @@ SELECT O.[UseCustomPermissions], O.[Seats], O.[MaxCollections], - O.[MaxStorageGb], + COALESCE(O.[MaxStorageGbIncreased], O.[MaxStorageGb]) AS [MaxStorageGb], O.[Identifier], PO.[Key], O.[PublicKey], @@ -44,7 +44,8 @@ SELECT O.[UseOrganizationDomains], O.[UseAutomaticUserConfirmation], SS.[Enabled] SsoEnabled, - SS.[Data] SsoConfig + SS.[Data] SsoConfig, + O.[UsePhishingBlocker] FROM [dbo].[ProviderUser] PU INNER JOIN diff --git a/src/Sql/dbo/Views/UserPremiumAccessView.sql b/src/Sql/dbo/Views/UserPremiumAccessView.sql new file mode 100644 index 0000000000..a20cab8fb3 --- /dev/null +++ b/src/Sql/dbo/Views/UserPremiumAccessView.sql @@ -0,0 +1,21 @@ +CREATE VIEW [dbo].[UserPremiumAccessView] +AS +SELECT + U.[Id], + U.[Premium] AS [PersonalPremium], + CAST( + MAX(CASE + WHEN O.[Id] IS NOT NULL THEN 1 + ELSE 0 + END) AS BIT + ) AS [OrganizationPremium] +FROM + [dbo].[User] U +LEFT JOIN + [dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id] +LEFT JOIN + [dbo].[Organization] O ON O.[Id] = OU.[OrganizationId] + AND O.[UsersGetPremium] = 1 + AND O.[Enabled] = 1 +GROUP BY + U.[Id], U.[Premium]; diff --git a/src/Sql/dbo/Views/UserView.sql b/src/Sql/dbo/Views/UserView.sql index 82fa8a2c63..fa8dbf334b 100644 --- a/src/Sql/dbo/Views/UserView.sql +++ b/src/Sql/dbo/Views/UserView.sql @@ -1,6 +1,51 @@ CREATE VIEW [dbo].[UserView] AS SELECT - * + [Id], + [Name], + [Email], + [EmailVerified], + [MasterPassword], + [MasterPasswordHint], + [Culture], + [SecurityStamp], + [TwoFactorProviders], + [TwoFactorRecoveryCode], + [EquivalentDomains], + [ExcludedGlobalEquivalentDomains], + [AccountRevisionDate], + [Key], + [PublicKey], + [PrivateKey], + [Premium], + [PremiumExpirationDate], + [RenewalReminderDate], + [Storage], + COALESCE([MaxStorageGbIncreased], [MaxStorageGb]) AS [MaxStorageGb], + [Gateway], + [GatewayCustomerId], + [GatewaySubscriptionId], + [ReferenceData], + [LicenseKey], + [ApiKey], + [Kdf], + [KdfIterations], + [KdfMemory], + [KdfParallelism], + [CreationDate], + [RevisionDate], + [ForcePasswordReset], + [UsesKeyConnector], + [FailedLoginCount], + [LastFailedLoginDate], + [AvatarColor], + [LastPasswordChangeDate], + [LastKdfChangeDate], + [LastKeyRotationDate], + [LastEmailChangeDate], + [VerifyDevices], + [SecurityState], + [SecurityVersion], + [SignedPublicKey] FROM [dbo].[User] diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs new file mode 100644 index 0000000000..71c6bf104c --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs @@ -0,0 +1,63 @@ +using System.Net; +using System.Text; +using System.Text.Json; +using Bit.Api.AdminConsole.Models.Request; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Request; +using Bit.Seeder.Recipes; +using Xunit; +using Xunit.Abstractions; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class GroupsControllerPerformanceTests(ITestOutputHelper testOutputHelper) +{ + /// + /// Tests PUT /organizations/{orgId}/groups/{id} + /// + [Theory(Skip = "Performance test")] + [InlineData(10, 5)] + //[InlineData(100, 10)] + //[InlineData(1000, 20)] + public async Task UpdateGroup_WithUsersAndCollections(int userCount, int collectionCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + var groupIds = groupsSeeder.AddToOrganization(orgId, 1, orgUserIds, 0); + + var groupId = groupIds.First(); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var updateRequest = new GroupRequestModel + { + Name = "Updated Group Name", + Collections = collectionIds.Select(c => new SelectionReadOnlyRequestModel { Id = c, ReadOnly = false, HidePasswords = false, Manage = false }), + Users = orgUserIds + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(updateRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/groups/{groupId}", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /organizations/{{orgId}}/groups/{{id}} - Users: {orgUserIds.Count}; Collections: {collectionIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs new file mode 100644 index 0000000000..8df1fcaf2b --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs @@ -0,0 +1,225 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUserControllerAutoConfirmTests : IClassFixture, IAsyncLifetime +{ + private const string _mockEncryptedString = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private string _ownerEmail = null!; + + public OrganizationUserControllerAutoConfirmTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-owner-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + } + + [Fact] + public async Task AutoConfirm_WhenUserCannotManageOtherUsers_ThenShouldReturnForbidden() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + var (confirmingUserEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, organization.Id, OrganizationUserType.User); + await _loginHelper.LoginAsync(confirmingUserEmail); + + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions { ManageUsers = false }, + OrganizationUserStatusType.Accepted); + + var result = await _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + }); + + Assert.Equal(HttpStatusCode.Forbidden, result.StatusCode); + + await _factory.GetService().DeleteAsync(organization); + } + + [Fact] + public async Task AutoConfirm_WhenOwnerConfirmsValidUser_ThenShouldReturnNoContent() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.AutomaticUserConfirmation, + Enabled = true + }); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.OrganizationDataOwnership, + Enabled = true + }); + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + await _loginHelper.LoginAsync(_ownerEmail); + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions(), + OrganizationUserStatusType.Accepted); + + var result = await _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + }); + + Assert.Equal(HttpStatusCode.NoContent, result.StatusCode); + + var orgUserRepository = _factory.GetService(); + var confirmedUser = await orgUserRepository.GetByIdAsync(organizationUser.Id); + Assert.NotNull(confirmedUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, confirmedUser.Status); + Assert.Equal(testKey, confirmedUser.Key); + + var collectionRepository = _factory.GetService(); + var collections = await collectionRepository.GetManyByUserIdAsync(organizationUser.UserId!.Value); + + Assert.NotEmpty(collections); + Assert.Single(collections.Where(c => c.Type == CollectionType.DefaultUserCollection)); + + await _factory.GetService().DeleteAsync(organization); + } + + [Fact] + public async Task AutoConfirm_WhenUserIsConfirmedMultipleTimes_ThenShouldSuccessAndOnlyConfirmOneUser() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.AutomaticUserConfirmation, + Enabled = true + }); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.OrganizationDataOwnership, + Enabled = true + }); + + await _loginHelper.LoginAsync(_ownerEmail); + + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions(), + OrganizationUserStatusType.Accepted); + + var tenRequests = Enumerable.Range(0, 10) + .Select(_ => _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + })).ToList(); + + var results = await Task.WhenAll(tenRequests); + + Assert.Contains(results, r => r.StatusCode == HttpStatusCode.NoContent); + + var orgUserRepository = _factory.GetService(); + var confirmedUser = await orgUserRepository.GetByIdAsync(organizationUser.Id); + Assert.NotNull(confirmedUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, confirmedUser.Status); + Assert.Equal(testKey, confirmedUser.Key); + + var collections = await _factory.GetService() + .GetManyByUserIdAsync(organizationUser.UserId!.Value); + Assert.NotEmpty(collections); + // validates user only received one default collection + Assert.Single(collections.Where(c => c.Type == CollectionType.DefaultUserCollection)); + + await _factory.GetService().DeleteAsync(organization); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs new file mode 100644 index 0000000000..6645f29eae --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs @@ -0,0 +1,347 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Response; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Providers.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUserControllerBulkRevokeTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + + public OrganizationUserControllerBulkRevokeTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-user-bulk-revoke-test-{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseMonthly, + ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task BulkRevoke_Success() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser1) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + var (_, orgUser2) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser1.Id, orgUser2.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Equal(2, content.Data.Count()); + Assert.All(content.Data, r => Assert.Empty(r.Error)); + + var actualUsers = await organizationUserRepository.GetManyAsync([orgUser1.Id, orgUser2.Id]); + Assert.All(actualUsers, u => Assert.Equal(OrganizationUserStatusType.Revoked, u.Status)); + } + + [Fact] + public async Task BulkRevoke_AsAdmin_Success() + { + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(adminEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.All(content.Data, r => Assert.Empty(r.Error)); + + var actualUser = await _factory.GetService().GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_CannotRevokeSelf_ReturnsError() + { + var (userEmail, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(userEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == orgUser.Id && r.Error == "You cannot revoke yourself."); + + var actualUser = await _factory.GetService().GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_AlreadyRevoked_ReturnsError() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + await organizationUserRepository.RevokeAsync(orgUser.Id); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == orgUser.Id && r.Error == "Already revoked."); + + var actualUser = await organizationUserRepository.GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_AdminCannotRevokeOwner_ReturnsError() + { + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(adminEmail); + + var (_, ownerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.Owner); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [ownerOrgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == ownerOrgUser.Id && r.Error == "Only owners can revoke other owners."); + + var actualUser = await _factory.GetService().GetByIdAsync(ownerOrgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_MixedResults() + { + var (ownerEmail, requestingOwner) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, validOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + var (_, alreadyRevokedOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + await organizationUserRepository.RevokeAsync(alreadyRevokedOrgUser.Id); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [validOrgUser.Id, alreadyRevokedOrgUser.Id, requestingOwner.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Equal(3, content.Data.Count()); + + Assert.Contains(content.Data, r => r.Id == validOrgUser.Id && r.Error == string.Empty); + Assert.Contains(content.Data, r => r.Id == alreadyRevokedOrgUser.Id && r.Error == "Already revoked."); + Assert.Contains(content.Data, r => r.Id == requestingOwner.Id && r.Error == "You cannot revoke yourself."); + + var actualUsers = await organizationUserRepository.GetManyAsync([validOrgUser.Id, alreadyRevokedOrgUser.Id, requestingOwner.Id]); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUsers.First(u => u.Id == validOrgUser.Id).Status); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUsers.First(u => u.Id == alreadyRevokedOrgUser.Id).Status); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUsers.First(u => u.Id == requestingOwner.Id).Status); + } + + [Theory] + [InlineData(OrganizationUserType.User)] + [InlineData(OrganizationUserType.Custom)] + public async Task BulkRevoke_WithoutManageUsersPermission_ReturnsForbidden(OrganizationUserType organizationUserType) + { + var (userEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, organizationUserType, new Permissions { ManageUsers = false }); + + await _loginHelper.LoginAsync(userEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [Guid.NewGuid()] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.Forbidden, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_WithEmptyIds_ReturnsBadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.BadRequest, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_WithInvalidOrganizationId_ReturnsForbidden() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var invalidOrgId = Guid.NewGuid(); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{invalidOrgId}/users/revoke", request); + + Assert.Equal(HttpStatusCode.Forbidden, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_ProviderRevokesOwner_ReturnsOk() + { + var providerEmail = $"provider-user{Guid.NewGuid()}@example.com"; + + // create user for provider + await _factory.LoginWithNewAccount(providerEmail); + + // create provider and provider user + await _factory.GetService() + .CreateBusinessUnitAsync( + new Provider + { + Name = "provider", + Type = ProviderType.BusinessUnit + }, + providerEmail, + PlanType.EnterpriseAnnually2023, + 10); + + await _loginHelper.LoginAsync(providerEmail); + + var providerUserUser = await _factory.GetService().GetByEmailAsync(providerEmail); + + var providerUserCollection = await _factory.GetService() + .GetManyByUserAsync(providerUserUser!.Id); + + var providerUser = providerUserCollection.First(); + + await _factory.GetService().CreateAsync(new ProviderOrganization + { + ProviderId = providerUser.ProviderId, + OrganizationId = _organization.Id, + Key = null, + Settings = null + }); + + var (_, ownerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [ownerOrgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs index 7c61a88bd8..0fef4a0cd0 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs @@ -218,7 +218,7 @@ public class OrganizationUserControllerTests : IClassFixture + /// Tests GET /organizations/{orgId}/users?includeCollections=true + ///
    [Theory(Skip = "Performance test")] - [InlineData(100)] - [InlineData(60000)] - public async Task GetAsync(int seats) + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task GetAllUsers_WithCollections(int seats) { await using var factory = new SqlServerApiApplicationFactory(); var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var seeder = new OrganizationWithUsersRecipe(db); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); - var orgId = seeder.Seed("Org", seats, "large.test"); + var domain = OrganizationTestHelpers.GenerateRandomDomain(); - var tokens = await factory.LoginAsync("admin@large.test", "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="); - client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); + groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); var stopwatch = System.Diagnostics.Stopwatch.StartNew(); var response = await client.GetAsync($"/organizations/{orgId}/users?includeCollections=true"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); - var result = await response.Content.ReadAsStringAsync(); - Assert.NotEmpty(result); + stopwatch.Stop(); + testOutputHelper.WriteLine($"GET /users - Seats: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + } + + /// + /// Tests GET /organizations/{orgId}/users/mini-details + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task GetAllUsers_MiniDetails(int seats) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); + groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/mini-details"); stopwatch.Stop(); - testOutputHelper.WriteLine($"Seed: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + testOutputHelper.WriteLine($"GET /users/mini-details - Seats: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests GET /organizations/{orgId}/users/{id}?includeGroups=true + /// + [Fact(Skip = "Performance test")] + public async Task GetSingleUser_WithGroups() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserId = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).FirstOrDefault(); + groupsSeeder.AddToOrganization(orgId, 2, [orgUserId]); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/{orgUserId}?includeGroups=true"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"GET /users/{{id}} - Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests GET /organizations/{orgId}/users/{id}/reset-password-details + /// + [Fact(Skip = "Performance test")] + public async Task GetResetPasswordDetails_ForSingleUser() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserId = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).FirstOrDefault(); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/{orgUserId}/reset-password-details"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"GET /users/{{id}}/reset-password-details - Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/confirm + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkConfirmUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Accepted); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var acceptedUserIds = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Status == OrganizationUserStatusType.Accepted) + .Select(ou => ou.Id) + .ToList(); + + var confirmRequest = new OrganizationUserBulkConfirmRequestModel + { + Keys = acceptedUserIds.Select(id => new OrganizationUserBulkConfirmRequestModelEntry { Id = id, Key = "test-key-" + id }), + DefaultUserCollectionName = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=" + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(confirmRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/confirm", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/confirm - Users: {acceptedUserIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/remove + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRemoveUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRemove = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var removeRequest = new OrganizationUserBulkRequestModel { Ids = usersToRemove }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var requestContent = new StringContent(JsonSerializer.Serialize(removeRequest), Encoding.UTF8, "application/json"); + + var response = await client.PostAsync($"/organizations/{orgId}/users/remove", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/remove - Users: {usersToRemove.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/revoke + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRevokeUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Confirmed); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRevoke = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var revokeRequest = new OrganizationUserBulkRequestModel { Ids = usersToRevoke }; + + var requestContent = new StringContent(JsonSerializer.Serialize(revokeRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/revoke", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/revoke - Users: {usersToRevoke.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/restore + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRestoreUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Revoked); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRestore = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var restoreRequest = new OrganizationUserBulkRequestModel { Ids = usersToRestore }; + + var requestContent = new StringContent(JsonSerializer.Serialize(restoreRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/restore", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/restore - Users: {usersToRestore.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/delete-account + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkDeleteAccounts(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var domainSeeder = new OrganizationDomainRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Confirmed); + + domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToDelete = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var deleteRequest = new OrganizationUserBulkRequestModel { Ids = usersToDelete }; + + var requestContent = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/delete-account", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/delete-account - Users: {usersToDelete.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/{id} + /// + [Fact(Skip = "Performance test")] + public async Task UpdateSingleUser_WithCollectionsAndGroups() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, 3, orgUserIds, 0); + var groupIds = groupsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var userToUpdate = db.OrganizationUsers + .FirstOrDefault(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User); + + var updateRequest = new OrganizationUserUpdateRequestModel + { + Type = OrganizationUserType.Custom, + Collections = collectionIds.Select(c => new SelectionReadOnlyRequestModel { Id = c, ReadOnly = false, HidePasswords = false, Manage = false }), + Groups = groupIds, + AccessSecretsManager = false, + Permissions = new Permissions { AccessEventLogs = true } + }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/{userToUpdate.Id}", + new StringContent(JsonSerializer.Serialize(updateRequest), Encoding.UTF8, "application/json")); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/{{id}} - Collections: {collectionIds.Count}; Groups: {groupIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/enable-secrets-manager + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkEnableSecretsManager(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToEnable = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var enableRequest = new OrganizationUserBulkRequestModel { Ids = usersToEnable }; + + var requestContent = new StringContent(JsonSerializer.Serialize(enableRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/enable-secrets-manager", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/enable-secrets-manager - Users: {usersToEnable.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests DELETE /organizations/{orgId}/users/{id}/delete-account + /// + [Fact(Skip = "Performance test")] + public async Task DeleteSingleUserAccount_FromVerifiedDomain() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var domainSeeder = new OrganizationDomainRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: 2, + usersStatus: OrganizationUserStatusType.Confirmed); + + domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var userToDelete = db.OrganizationUsers + .FirstOrDefault(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.DeleteAsync($"/organizations/{orgId}/users/{userToDelete.Id}/delete-account"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"DELETE /users/{{id}}/delete-account - Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/invite + /// + [Theory(Skip = "Performance test")] + [InlineData(1)] + //[InlineData(5)] + //[InlineData(20)] + public async Task InviteUsers(int emailCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var emails = Enumerable.Range(0, emailCount).Select(i => $"{i:D4}@{domain}").ToArray(); + var inviteRequest = new OrganizationUserInviteRequestModel + { + Emails = emails, + Type = OrganizationUserType.User, + AccessSecretsManager = false, + Collections = Array.Empty(), + Groups = Array.Empty(), + Permissions = null + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(inviteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/invite", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/invite - Emails: {emails.Length}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/reinvite + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkReinviteUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Invited); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToReinvite = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Status == OrganizationUserStatusType.Invited) + .Select(ou => ou.Id) + .ToList(); + + var reinviteRequest = new OrganizationUserBulkRequestModel { Ids = usersToReinvite }; + + var requestContent = new StringContent(JsonSerializer.Serialize(reinviteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/reinvite", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/reinvite - Users: {usersToReinvite.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); } } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs index cf842d1568..38e3cac863 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs @@ -3,7 +3,6 @@ using Bit.Api.AdminConsole.Authorization; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Api.Models.Request.Organizations; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums; @@ -14,8 +13,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; using Bit.Core.Repositories; -using Bit.Core.Services; -using NSubstitute; using Xunit; namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; @@ -32,12 +29,6 @@ public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture(featureService => - { - featureService - .IsEnabled(FeatureFlagKeys.AccountRecoveryCommand) - .Returns(true); - }); _client = _factory.CreateClient(); _loginHelper = new LoginHelper(_factory, _client); } @@ -47,7 +38,7 @@ public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture + /// Tests DELETE /organizations/{id} with password verification + ///
    + [Theory(Skip = "Performance test")] + [InlineData(10, 5, 3)] + //[InlineData(100, 20, 10)] + //[InlineData(1000, 50, 25)] + public async Task DeleteOrganization_WithPasswordVerification(int userCount, int collectionCount, int groupCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var deleteRequest = new SecretVerificationRequestModel + { + MasterPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs=" + }; + + var request = new HttpRequestMessage(HttpMethod.Delete, $"/organizations/{orgId}") + { + Content = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json") + }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + + var response = await client.SendAsync(request); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"DELETE /organizations/{{id}} - Users: {userCount}; Collections: {collectionCount}; Groups: {groupCount}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{id}/delete-recover-token with token verification + /// + [Theory(Skip = "Performance test")] + [InlineData(10, 5, 3)] + //[InlineData(100, 20, 10)] + //[InlineData(1000, 50, 25)] + public async Task DeleteOrganization_WithTokenVerification(int userCount, int collectionCount, int groupCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var organization = db.Organizations.FirstOrDefault(o => o.Id == orgId); + Assert.NotNull(organization); + + var tokenFactory = factory.GetService>(); + var tokenable = new OrgDeleteTokenable(organization, 24); + var token = tokenFactory.Protect(tokenable); + + var deleteRequest = new OrganizationVerifyDeleteRecoverRequestModel + { + Token = token + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/delete-recover-token", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /organizations/{{id}}/delete-recover-token - Users: {userCount}; Collections: {collectionCount}; Groups: {groupCount}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/create-without-payment + /// + [Fact(Skip = "Performance test")] + public async Task CreateOrganization_WithoutPayment() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var email = $"user@{OrganizationTestHelpers.GenerateRandomDomain()}"; + var masterPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="; + + await factory.LoginWithNewAccount(email, masterPasswordHash); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, email, masterPasswordHash); + + var createRequest = new OrganizationNoPaymentCreateRequest + { + Name = "Test Organization", + BusinessName = "Test Business Name", + BillingEmail = email, + PlanType = PlanType.EnterpriseAnnually, + Key = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=", + AdditionalSeats = 1, + AdditionalStorageGb = 1, + UseSecretsManager = true, + AdditionalSmSeats = 1, + AdditionalServiceAccounts = 2, + MaxAutoscaleSeats = 100, + PremiumAccessAddon = false, + CollectionName = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=" + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(createRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync("/organizations/create-without-payment", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /organizations/create-without-payment - AdditionalSeats: {createRequest.AdditionalSeats}; AdditionalStorageGb: {createRequest.AdditionalStorageGb}; AdditionalSmSeats: {createRequest.AdditionalSmSeats}; AdditionalServiceAccounts: {createRequest.AdditionalServiceAccounts}; MaxAutoscaleSeats: {createRequest.MaxAutoscaleSeats}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs new file mode 100644 index 0000000000..c234e77bc8 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -0,0 +1,196 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationsControllerTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + private readonly string _billingEmail = "billing@example.com"; + private readonly string _organizationName = "Organizations Controller Test Org"; + + public OrganizationsControllerTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-integration-test-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, + name: _organizationName, + billingEmail: _billingEmail, + plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, + passwordManagerSeats: 5, + paymentMethod: PaymentMethodType.Card); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task Put_AsOwner_WithoutProvider_CanUpdateOrganization() + { + // Arrange - Regular organization owner (no provider) + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal("newbillingemail@example.com", updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsProvider_CanUpdateOrganization() + { + // Create and login as a new account to be the provider user (not the owner) + var providerUserEmail = $"provider-{Guid.NewGuid()}@example.com"; + var (token, _) = await _factory.LoginWithNewAccount(providerUserEmail); + + // Set up provider linked to org and ProviderUser entry + var provider = await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, + ProviderType.Msp); + await ProviderTestHelpers.CreateProviderUserAsync(_factory, provider.Id, providerUserEmail, + ProviderUserType.ProviderAdmin); + + await _loginHelper.LoginAsync(providerUserEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal("newbillingemail@example.com", updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_NotMemberOrProvider_CannotUpdateOrganization() + { + // Create and login as a new account to be unrelated to the org + var userEmail = "stranger@example.com"; + await _factory.LoginWithNewAccount(userEmail); + await _loginHelper.LoginAsync(userEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + + // Verify the organization name was not updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal(_organizationName, updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsOwner_WithProvider_CanRenameOrganization() + { + // Arrange - Create provider and link to organization + // The active user is ONLY an org owner, NOT a provider user + await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, ProviderType.Msp); + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = null + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was actually updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsOwner_WithProvider_CannotChangeBillingEmail() + { + // Arrange - Create provider and link to organization + // The active user is ONLY an org owner, NOT a provider user + await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, ProviderType.Msp); + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "updatedbilling@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + + // Verify the organization was not updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal(_organizationName, updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs b/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs index 32c7f75a2b..6ba65f6453 100644 --- a/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs @@ -33,7 +33,7 @@ public class ImportOrganizationUsersAndGroupsCommandTests : IClassFixture, IAsy await _factory.LoginWithNewAccount(_ownerEmail); // Create the organization - (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); // Authorize with the organization api key diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index 0b5ab660b9..6144d7eebb 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -39,7 +39,7 @@ public class PoliciesControllerTests : IClassFixture, IAs await _factory.LoginWithNewAccount(_ownerEmail); // Create the organization - (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); // Authorize with the organization api key diff --git a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs index c23ebff736..bcde370b24 100644 --- a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs +++ b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs @@ -192,6 +192,15 @@ public static class OrganizationTestHelpers await policyRepository.CreateAsync(policy); } + /// + /// Generates a unique random domain name for testing purposes. + /// + /// A domain string like "a1b2c3d4.com" + public static string GenerateRandomDomain() + { + return $"{Guid.NewGuid().ToString("N").Substring(0, 8)}.com"; + } + /// /// Creates a user account without a Master Password and adds them as a member to the specified organization. /// diff --git a/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs b/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs new file mode 100644 index 0000000000..ca26266dfa --- /dev/null +++ b/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs @@ -0,0 +1,32 @@ +using System.Net.Http.Headers; +using Bit.Api.IntegrationTest.Factories; + +namespace Bit.Api.IntegrationTest.Helpers; + +/// +/// Helper methods for performance tests to reduce code duplication. +/// +public static class PerformanceTestHelpers +{ + /// + /// Standard password hash used across performance tests. + /// + public const string StandardPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="; + + /// + /// Authenticates an HttpClient with a bearer token for the specified user. + /// + /// The application factory to use for login. + /// The HttpClient to authenticate. + /// The user's email address. + /// The user's master password hash. Defaults to StandardPasswordHash. + public static async Task AuthenticateClientAsync( + SqlServerApiApplicationFactory factory, + HttpClient client, + string email, + string? masterPasswordHash = null) + { + var tokens = await factory.LoginAsync(email, masterPasswordHash ?? StandardPasswordHash); + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } +} diff --git a/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs b/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs new file mode 100644 index 0000000000..ab52bcd076 --- /dev/null +++ b/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs @@ -0,0 +1,77 @@ +using Bit.Api.IntegrationTest.Factories; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Repositories; + +namespace Bit.Api.IntegrationTest.Helpers; + +public static class ProviderTestHelpers +{ + /// + /// Creates a provider and links it to an organization. + /// This does NOT create any provider users. + /// + /// The API application factory + /// The organization ID to link to the provider + /// The type of provider to create + /// The provider status (defaults to Created) + /// The created provider + public static async Task CreateProviderAndLinkToOrganizationAsync( + ApiApplicationFactory factory, + Guid organizationId, + ProviderType providerType, + ProviderStatusType providerStatus = ProviderStatusType.Created) + { + var providerRepository = factory.GetService(); + var providerOrganizationRepository = factory.GetService(); + + // Create the provider + var provider = await providerRepository.CreateAsync(new Provider + { + Name = $"Test {providerType} Provider", + BusinessName = $"Test {providerType} Provider Business", + BillingEmail = $"provider-{providerType.ToString().ToLower()}@example.com", + Type = providerType, + Status = providerStatus, + Enabled = true + }); + + // Link the provider to the organization + await providerOrganizationRepository.CreateAsync(new ProviderOrganization + { + ProviderId = provider.Id, + OrganizationId = organizationId, + Key = "test-provider-key" + }); + + return provider; + } + + /// + /// Creates a providerUser for a provider. + /// + public static async Task CreateProviderUserAsync( + ApiApplicationFactory factory, + Guid providerId, + string userEmail, + ProviderUserType providerUserType) + { + var userRepository = factory.GetService(); + var user = await userRepository.GetByEmailAsync(userEmail); + if (user is null) + { + throw new Exception("No user found in test setup."); + } + + var providerUserRepository = factory.GetService(); + return await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = providerId, + Status = ProviderUserStatusType.Confirmed, + UserId = user.Id, + Key = Guid.NewGuid().ToString(), + Type = providerUserType + }); + } +} diff --git a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index 1630bc0dc0..1c456df106 100644 --- a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -3,9 +3,11 @@ using System.Net; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Api.KeyManagement.Models.Requests; +using Bit.Api.KeyManagement.Models.Responses; using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models; using Bit.Api.Vault.Models.Request; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Api.Request.Accounts; @@ -286,20 +288,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); + + Assert.NotNull(result); + Assert.Equal(organization.Name, result.OrganizationName); + } + + private async Task<(string, Organization)> SetupKeyConnectorTestAsync(OrganizationUserStatusType userStatusType, + string organizationSsoIdentifier = "test-sso-identifier") + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, + PlanType.EnterpriseAnnually, _ownerEmail, passwordManagerSeats: 10, + paymentMethod: PaymentMethodType.Card); + organization.UseKeyConnector = true; + organization.UseSso = true; + organization.Identifier = organizationSsoIdentifier; + await _organizationRepository.ReplaceAsync(organization); + + var ssoUserEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(ssoUserEmail); + await _loginHelper.LoginAsync(ssoUserEmail); + + await OrganizationTestHelpers.CreateUserAsync(_factory, organization.Id, ssoUserEmail, + OrganizationUserType.User, userStatusType: userStatusType); + + return (ssoUserEmail, organization); + } } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretVersionsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretVersionsControllerTests.cs new file mode 100644 index 0000000000..9393795e55 --- /dev/null +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretVersionsControllerTests.cs @@ -0,0 +1,289 @@ +using System.Net; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.SecretsManager.Enums; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; +using Bit.Api.Models.Response; +using Bit.Api.SecretsManager.Models.Request; +using Bit.Api.SecretsManager.Models.Response; +using Bit.Core.Enums; +using Bit.Core.SecretsManager.Entities; +using Bit.Core.SecretsManager.Repositories; +using Xunit; + +namespace Bit.Api.IntegrationTest.SecretsManager.Controllers; + +public class SecretVersionsControllerTests : IClassFixture, IAsyncLifetime +{ + private readonly string _mockEncryptedString = + "2.3Uk+WNBIoU5xzmVFNcoWzz==|1MsPIYuRfdOHfu/0uY6H2Q==|/98sp4wb6pHP1VTZ9JcNCYgQjEUMFPlqJgCwRk1YXKg="; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly ISecretRepository _secretRepository; + private readonly ISecretVersionRepository _secretVersionRepository; + private readonly IAccessPolicyRepository _accessPolicyRepository; + private readonly LoginHelper _loginHelper; + + private string _email = null!; + private SecretsManagerOrganizationHelper _organizationHelper = null!; + + public SecretVersionsControllerTests(ApiApplicationFactory factory) + { + _factory = factory; + _client = _factory.CreateClient(); + _secretRepository = _factory.GetService(); + _secretVersionRepository = _factory.GetService(); + _accessPolicyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _email = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_email); + _organizationHelper = new SecretsManagerOrganizationHelper(_factory, _email); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Theory] + [InlineData(false, false, false)] + [InlineData(false, false, true)] + [InlineData(false, true, false)] + [InlineData(false, true, true)] + [InlineData(true, false, false)] + [InlineData(true, false, true)] + [InlineData(true, true, false)] + public async Task GetVersionsBySecretId_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) + { + var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Theory] + [InlineData(PermissionType.RunAsAdmin)] + [InlineData(PermissionType.RunAsUserWithPermission)] + public async Task GetVersionsBySecretId_Success(PermissionType permissionType) + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + // Create some versions + var version1 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var version2 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + if (permissionType == PermissionType.RunAsUserWithPermission) + { + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await _loginHelper.LoginAsync(email); + + var accessPolicies = new List + { + new UserSecretAccessPolicy + { + GrantedSecretId = secret.Id, + OrganizationUserId = orgUser.Id, + Read = true, + Write = true + } + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + } + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync>(); + + Assert.NotNull(result); + Assert.Equal(2, result.Data.Count()); + } + + [Fact] + public async Task GetVersionById_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var version = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow + }); + + var response = await _client.GetAsync($"/secret-versions/{version.Id}"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync(); + + Assert.NotNull(result); + Assert.Equal(version.Id, result.Id); + Assert.Equal(secret.Id, result.SecretId); + } + + [Fact] + public async Task RestoreVersion_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = "OriginalValue", + Note = _mockEncryptedString + }); + + var version = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "OldValue", + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + var request = new RestoreSecretVersionRequestModel + { + VersionId = version.Id + }; + + var response = await _client.PutAsJsonAsync($"/secrets/{secret.Id}/versions/restore", request); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync(); + + Assert.NotNull(result); + Assert.Equal("OldValue", result.Value); + } + + [Fact] + public async Task BulkDelete_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var version1 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var version2 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + var ids = new List { version1.Id, version2.Id }; + + var response = await _client.PostAsJsonAsync("/secret-versions/delete", ids); + response.EnsureSuccessStatusCode(); + + var versions = await _secretVersionRepository.GetManyBySecretIdAsync(secret.Id); + Assert.Empty(versions); + } + + [Fact] + public async Task GetVersionsBySecretId_ReturnsOrderedByVersionDate() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + // Create versions in random order + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version2", + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version3", + VersionDate = DateTime.UtcNow + }); + + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version1", + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync>(); + + Assert.NotNull(result); + Assert.Equal(3, result.Data.Count()); + + var versions = result.Data.ToList(); + // Should be ordered by VersionDate descending (newest first) + Assert.Equal("Version3", versions[0].Value); + Assert.Equal("Version2", versions[1].Value); + Assert.Equal("Version1", versions[2].Value); + } +} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs index 335859e0c4..c9131f3505 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs @@ -2,15 +2,14 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; using NSubstitute; -using NSubstitute.ReturnsExtensions; using Xunit; namespace Bit.Api.Test.AdminConsole.Controllers; @@ -19,7 +18,7 @@ namespace Bit.Api.Test.AdminConsole.Controllers; [SutProviderCustomize] public class OrganizationIntegrationControllerTests { - private OrganizationIntegrationRequestModel _webhookRequestModel = new OrganizationIntegrationRequestModel() + private readonly OrganizationIntegrationRequestModel _webhookRequestModel = new() { Configuration = null, Type = IntegrationType.Webhook @@ -48,13 +47,13 @@ public class OrganizationIntegrationControllerTests sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() + sutProvider.GetDependency() .GetManyByOrganizationAsync(organizationId) .Returns(integrations); var result = await sutProvider.Sut.GetAsync(organizationId); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .GetManyByOrganizationAsync(organizationId); Assert.Equal(integrations.Count, result.Count); @@ -70,7 +69,7 @@ public class OrganizationIntegrationControllerTests sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() + sutProvider.GetDependency() .GetManyByOrganizationAsync(organizationId) .Returns([]); @@ -80,199 +79,133 @@ public class OrganizationIntegrationControllerTests } [Theory, BitAutoData] - public async Task CreateAsync_Webhook_AllParamsProvided_Succeeds( + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + + var response = await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(i => + i.OrganizationId == organizationId && + i.Type == IntegrationType.Webhook)); + Assert.IsType(response); + } + + [Theory, BitAutoData] + public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(callInfo => callInfo.Arg()); - var response = await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(response); - Assert.Equal(IntegrationType.Webhook, response.Type); - } - - [Theory, BitAutoData] - public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel)); } [Theory, BitAutoData] public async Task DeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration) + Guid integrationId) { - organizationIntegration.OrganizationId = organizationId; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id); + await sutProvider.Sut.DeleteAsync(organizationId, integrationId); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegration); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId); } [Theory, BitAutoData] + [Obsolete("Obsolete")] public async Task PostDeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegration); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = Guid.NewGuid(); - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); + await sutProvider.Sut.PostDeleteAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId); } [Theory, BitAutoData] public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.DeleteAsync(organizationId, integrationId)); } [Theory, BitAutoData] public async Task UpdateAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration) + Guid integrationId, + OrganizationIntegration integration) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; + integration.OrganizationId = organizationId; + integration.Id = integrationId; + integration.Type = IntegrationType.Webhook; + sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); + sutProvider.GetDependency() + .UpdateAsync(organizationId, integrationId, Arg.Any()) + .Returns(integration); - var response = await sutProvider.Sut.UpdateAsync(organizationId, organizationIntegration.Id, _webhookRequestModel); + var response = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, _webhookRequestModel); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(organizationIntegration); + await sutProvider.GetDependency().Received(1) + .UpdateAsync(organizationId, integrationId, Arg.Is(i => + i.OrganizationId == organizationId && + i.Type == IntegrationType.Webhook)); Assert.IsType(response); Assert.Equal(IntegrationType.Webhook, response.Type); } - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = Guid.NewGuid(); - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); - } - [Theory, BitAutoData] public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, _webhookRequestModel)); } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs index 9ab626d3f0..6e1dadb92f 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs @@ -1,18 +1,14 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Controllers; +using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; using Bit.Core.Context; -using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; using NSubstitute; -using NSubstitute.ReturnsExtensions; using Xunit; namespace Bit.Api.Test.AdminConsole.Controllers; @@ -25,823 +21,191 @@ public class OrganizationIntegrationsConfigurationControllerTests public async Task DeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) + Guid integrationId, + Guid configurationId) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegrationConfiguration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegrationConfiguration); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); } [Theory, BitAutoData] + [Obsolete("Obsolete")] public async Task PostDeleteAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegrationConfiguration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegrationConfiguration); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } + await sutProvider.Sut.PostDeleteAsync(organizationId, integrationId, configurationId); - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigDoesNotBelongToIntegration_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = Guid.Empty; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); } [Theory, BitAutoData] public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); } [Theory, BitAutoData] public async Task GetAsync_ConfigurationsExist_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - List organizationIntegrationConfigurations) + Guid integrationId, + List configurations) { - organizationIntegration.OrganizationId = organizationId; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) - .Returns(organizationIntegrationConfigurations); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); Assert.NotNull(result); - Assert.Equal(organizationIntegrationConfigurations.Count, result.Count); + Assert.Equal(configurations.Count, result.Count); Assert.All(result, r => Assert.IsType(r)); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); } [Theory, BitAutoData] public async Task GetAsync_NoConfigurationsExist_ReturnsEmptyList( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration) + Guid integrationId) { - organizationIntegration.OrganizationId = organizationId; sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) .Returns([]); - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); + Assert.NotNull(result); Assert.Empty(result); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id)); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); } [Theory, BitAutoData] public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, - Guid organizationId) + Guid organizationId, + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.GetAsync(organizationId, integrationId)); } [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Slack_Succeeds( + public async Task PostAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, + Guid integrationId, + OrganizationIntegrationConfiguration configuration, OrganizationIntegrationConfigurationRequestModel model) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); + sutProvider.GetDependency() + .CreateAsync(organizationId, integrationId, Arg.Any()) + .Returns(configuration); - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, integrationId, model); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(organizationId, integrationId, Arg.Any()); Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); } [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Webhook_Succeeds( + public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(createResponse); - Assert.Equal(expected.Id, createResponse.Id); - Assert.Equal(expected.Configuration, createResponse.Configuration); - Assert.Equal(expected.EventType, createResponse.EventType); - Assert.Equal(expected.Filters, createResponse.Filters); - Assert.Equal(expected.Template, createResponse.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeCloudBillingSync_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.CloudBillingSync; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeScim_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Scim; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + Guid integrationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, Guid.Empty, new OrganizationIntegrationConfigurationRequestModel())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(organizationId, integrationId, new OrganizationIntegrationConfigurationRequestModel())); } [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Slack_Succeeds( + public async Task UpdateAsync_AllParamsProvided_Succeeds( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration configuration, OrganizationIntegrationConfigurationRequestModel model) { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); + sutProvider.GetDependency() + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()) + .Returns(configuration); - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); + var updateResponse = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, model); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()); Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); } - [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Webhook_Succeeds( + public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( SutProvider sutProvider, Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var updateResponse = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(updateResponse); - Assert.Equal(expected.Id, updateResponse.Id); - Assert.Equal(expected.Configuration, updateResponse.Configuration); - Assert.Equal(expected.EventType, updateResponse.EventType); - Assert.Equal(expected.Filters, updateResponse.Filters); - Assert.Equal(expected.Template, updateResponse.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + Guid integrationId, + Guid configurationId) { sutProvider.Sut.Url = Substitute.For(); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, new OrganizationIntegrationConfigurationRequestModel())); } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs index 5875cda05a..43f0123a3f 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs @@ -9,10 +9,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Results; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Repositories; using Bit.Core.Context; @@ -33,9 +36,11 @@ using Bit.Test.Common.AutoFixture.Attributes; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc.ModelBinding; using NSubstitute; +using OneOf.Types; using Xunit; namespace Bit.Api.Test.AdminConsole.Controllers; @@ -448,90 +453,38 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagDisabled_CallsLegacyPath( + public async Task PutResetPassword_WhenOrganizationUserNotFound_ReturnsNotFound( Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); - sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); - sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) - .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); - - var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); - - Assert.IsType(result); - await sutProvider.GetDependency().Received(1) - .AdminResetPasswordAsync(OrganizationUserType.Owner, orgId, orgUserId, model.NewMasterPasswordHash, model.Key); - } - - [Theory] - [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagDisabled_WhenOrgUserTypeIsNull_ReturnsNotFound( - Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); - sutProvider.GetDependency().OrganizationOwner(orgId).Returns(false); - sutProvider.GetDependency().Organizations.Returns(new List()); - - var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); - - Assert.IsType(result); - } - - [Theory] - [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagDisabled_WhenAdminResetPasswordFails_ReturnsBadRequest( - Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); - sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); - sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) - .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error 1" })); - - var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); - - Assert.IsType>(result); - } - - [Theory] - [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationUserNotFound_ReturnsNotFound( - Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns((OrganizationUser)null); var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); - Assert.IsType(result); + Assert.IsType(result); } [Theory] [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationIdMismatch_ReturnsNotFound( + public async Task PutResetPassword_WhenOrganizationIdMismatch_ReturnsNotFound( Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, SutProvider sutProvider) { organizationUser.OrganizationId = Guid.NewGuid(); - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); - Assert.IsType(result); + Assert.IsType(result); } [Theory] [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagEnabled_WhenAuthorizationFails_ReturnsBadRequest( + public async Task PutResetPassword_WhenAuthorizationFails_ReturnsBadRequest( Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, SutProvider sutProvider) { organizationUser.OrganizationId = orgId; - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); sutProvider.GetDependency() .AuthorizeAsync( @@ -547,12 +500,11 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountSucceeds_ReturnsOk( + public async Task PutResetPassword_WhenRecoverAccountSucceeds_ReturnsOk( Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, SutProvider sutProvider) { organizationUser.OrganizationId = orgId; - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); sutProvider.GetDependency() .AuthorizeAsync( @@ -573,12 +525,11 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] - public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountFails_ReturnsBadRequest( + public async Task PutResetPassword_WhenRecoverAccountFails_ReturnsBadRequest( Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, SutProvider sutProvider) { organizationUser.OrganizationId = orgId; - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); sutProvider.GetDependency() .AuthorizeAsync( @@ -594,4 +545,254 @@ public class OrganizationUsersControllerTests Assert.IsType>(result); } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_UserIdNull_ReturnsUnauthorized( + Guid orgId, + Guid orgUserId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns((Guid?)null); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_UserIdEmpty_ReturnsUnauthorized( + Guid orgId, + Guid orgUserId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(Guid.Empty); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_Success_ReturnsOk( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(new None())); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_NotFoundError_ReturnsNotFound( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(false); + + var notFoundError = new OrganizationNotFound(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(notFoundError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var notFoundResult = Assert.IsType>(result); + Assert.Equal(notFoundError.Message, notFoundResult.Value.Message); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_BadRequestError_ReturnsBadRequest( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + var badRequestError = new UserIsNotAccepted(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(badRequestError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var badRequestResult = Assert.IsType>(result); + Assert.Equal(badRequestError.Message, badRequestResult.Value.Message); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_InternalError_ReturnsProblem( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + var internalError = new FailedToWriteToEventLog(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(internalError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var problemResult = Assert.IsType>(result); + Assert.Equal(StatusCodes.Status500InternalServerError, problemResult.StatusCode); + } + + [Theory] + [BitAutoData] + public async Task BulkReinvite_WhenFeatureFlagEnabled_UsesBulkResendOrganizationInvitesCommand( + Guid organizationId, + OrganizationUserBulkRequestModel bulkRequestModel, + List organizationUsers, + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) + .Returns(true); + + var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); + sutProvider.GetDependency() + .BulkResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids) + .Returns(expectedResults); + + // Act + var response = await sutProvider.Sut.BulkReinvite(organizationId, bulkRequestModel); + + // Assert + Assert.Equal(organizationUsers.Count, response.Data.Count()); + + await sutProvider.GetDependency() + .Received(1) + .BulkResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); + } + + [Theory] + [BitAutoData] + public async Task BulkReinvite_WhenFeatureFlagDisabled_UsesLegacyOrganizationService( + Guid organizationId, + OrganizationUserBulkRequestModel bulkRequestModel, + List organizationUsers, + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) + .Returns(false); + + var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); + sutProvider.GetDependency() + .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids) + .Returns(expectedResults); + + // Act + var response = await sutProvider.Sut.BulkReinvite(organizationId, bulkRequestModel); + + // Assert + Assert.Equal(organizationUsers.Count, response.Data.Count()); + + await sutProvider.GetDependency() + .Received(1) + .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); + } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index 00fd3c3b4e..d87f035a13 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -1,5 +1,4 @@ using System.Security.Claims; -using AutoFixture.Xunit2; using Bit.Api.AdminConsole.Controllers; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Models.Request.Organizations; @@ -8,9 +7,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business; -using Bit.Core.AdminConsole.Models.Business.Tokenables; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; @@ -20,7 +16,6 @@ using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; -using Bit.Core.Auth.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; @@ -30,102 +25,24 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Tokens; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -using GlobalSettings = Bit.Core.Settings.GlobalSettings; namespace Bit.Api.Test.AdminConsole.Controllers; -public class OrganizationsControllerTests : IDisposable +[ControllerCustomize(typeof(OrganizationsController))] +[SutProviderCustomize] +public class OrganizationsControllerTests { - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IUserService _userService; - private readonly IGetOrganizationApiKeyQuery _getOrganizationApiKeyQuery; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly ICreateOrganizationApiKeyCommand _createOrganizationApiKeyCommand; - private readonly IFeatureService _featureService; - private readonly IProviderRepository _providerRepository; - private readonly IProviderBillingService _providerBillingService; - private readonly IDataProtectorTokenFactory _orgDeleteTokenDataFactory; - private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; - private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand; - private readonly IOrganizationDeleteCommand _organizationDeleteCommand; - private readonly IPolicyRequirementQuery _policyRequirementQuery; - private readonly IPricingClient _pricingClient; - private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand; - private readonly OrganizationsController _sut; - - public OrganizationsControllerTests() - { - _currentContext = Substitute.For(); - _globalSettings = Substitute.For(); - _organizationRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _policyRepository = Substitute.For(); - _ssoConfigRepository = Substitute.For(); - _ssoConfigService = Substitute.For(); - _getOrganizationApiKeyQuery = Substitute.For(); - _rotateOrganizationApiKeyCommand = Substitute.For(); - _organizationApiKeyRepository = Substitute.For(); - _userService = Substitute.For(); - _createOrganizationApiKeyCommand = Substitute.For(); - _featureService = Substitute.For(); - _providerRepository = Substitute.For(); - _providerBillingService = Substitute.For(); - _orgDeleteTokenDataFactory = Substitute.For>(); - _removeOrganizationUserCommand = Substitute.For(); - _cloudOrganizationSignUpCommand = Substitute.For(); - _organizationDeleteCommand = Substitute.For(); - _policyRequirementQuery = Substitute.For(); - _pricingClient = Substitute.For(); - _organizationUpdateKeysCommand = Substitute.For(); - - _sut = new OrganizationsController( - _organizationRepository, - _organizationUserRepository, - _policyRepository, - _organizationService, - _userService, - _currentContext, - _ssoConfigRepository, - _ssoConfigService, - _getOrganizationApiKeyQuery, - _rotateOrganizationApiKeyCommand, - _createOrganizationApiKeyCommand, - _organizationApiKeyRepository, - _featureService, - _globalSettings, - _providerRepository, - _providerBillingService, - _orgDeleteTokenDataFactory, - _removeOrganizationUserCommand, - _cloudOrganizationSignUpCommand, - _organizationDeleteCommand, - _policyRequirementQuery, - _pricingClient, - _organizationUpdateKeysCommand); - } - - public void Dispose() - { - _sut?.Dispose(); - } - - [Theory, AutoData] + [Theory, BitAutoData] public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( - Guid orgId, User user) + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -140,21 +57,24 @@ public class OrganizationsControllerTests : IDisposable user.UsesKeyConnector = true; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { null }); - var exception = await Assert.ThrowsAsync(() => _sut.Leave(orgId)); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { null }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Leave(orgId)); Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", exception.Message); - await _removeOrganizationUserCommand.DidNotReceiveWithAnyArgs().UserLeaveAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UserLeaveAsync(default, default); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task OrganizationsController_UserCannotLeaveOrganizationThatManagesUser( - Guid orgId, User user) + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -166,27 +86,34 @@ public class OrganizationsControllerTests : IDisposable Enabled = true, OrganizationId = orgId, }; - var foundOrg = new Organization(); - foundOrg.Id = orgId; + var foundOrg = new Organization + { + Id = orgId + }; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { { foundOrg } }); - var exception = await Assert.ThrowsAsync(() => _sut.Leave(orgId)); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { foundOrg }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Leave(orgId)); Assert.Contains("Claimed user account cannot leave claiming organization. Contact your organization administrator for additional details.", exception.Message); - await _removeOrganizationUserCommand.DidNotReceiveWithAnyArgs().RemoveUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().RemoveUserAsync(default, default); } [Theory] - [InlineAutoData(true, false)] - [InlineAutoData(false, true)] - [InlineAutoData(false, false)] + [BitAutoData(true, false)] + [BitAutoData(false, true)] + [BitAutoData(false, false)] public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( - bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) + bool keyConnectorEnabled, + bool userUsesKeyConnector, + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -203,18 +130,19 @@ public class OrganizationsControllerTests : IDisposable user.UsesKeyConnector = userUsesKeyConnector; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List()); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List()); - await _sut.Leave(orgId); + await sutProvider.Sut.Leave(orgId); - await _removeOrganizationUserCommand.Received(1).UserLeaveAsync(orgId, user.Id); + await sutProvider.GetDependency().Received(1).UserLeaveAsync(orgId, user.Id); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task Delete_OrganizationIsConsolidatedBillingClient_ScalesProvidersSeats( + SutProvider sutProvider, Provider provider, Organization organization, User user, @@ -228,87 +156,89 @@ public class OrganizationsControllerTests : IDisposable provider.Type = ProviderType.Msp; provider.Status = ProviderStatusType.Billable; - _currentContext.OrganizationOwner(organizationId).Returns(true); + sutProvider.GetDependency().OrganizationOwner(organizationId).Returns(true); + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().VerifySecretAsync(user, requestModel.Secret).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); - _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + await sutProvider.Sut.Delete(organizationId.ToString(), requestModel); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - - _userService.VerifySecretAsync(user, requestModel.Secret).Returns(true); - - _providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider); - - await _sut.Delete(organizationId.ToString(), requestModel); - - await _providerBillingService.Received(1) + await sutProvider.GetDependency().Received(1) .ScaleSeats(provider, organization.PlanType, -organization.Seats.Value); - await _organizationDeleteCommand.Received(1).DeleteAsync(organization); + await sutProvider.GetDependency().Received(1).DeleteAsync(organization); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task GetAutoEnrollStatus_WithPolicyRequirementsEnabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue( + SutProvider sutProvider, User user, Organization organization, - OrganizationUser organizationUser - ) + OrganizationUser organizationUser) { - var policyRequirement = new ResetPasswordPolicyRequirement() { AutoEnrollOrganizations = [organization.Id] }; + var policyRequirement = new ResetPasswordPolicyRequirement { AutoEnrollOrganizations = [organization.Id] }; - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); - _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); - _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); - _policyRequirementQuery.GetAsync(user.Id).Returns(policyRequirement); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); + sutProvider.GetDependency().GetAsync(user.Id).Returns(policyRequirement); - var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString()); + var result = await sutProvider.Sut.GetAutoEnrollStatus(organization.Id.ToString()); - await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any()); - await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString()); - await _policyRequirementQuery.Received(1).GetAsync(user.Id); + await sutProvider.GetDependency().Received(1).GetUserByPrincipalAsync(Arg.Any()); + await sutProvider.GetDependency().Received(1).GetByIdentifierAsync(organization.Id.ToString()); + await sutProvider.GetDependency().Received(1).GetAsync(user.Id); Assert.True(result.ResetPasswordEnabled); Assert.Equal(result.Id, organization.Id); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task GetAutoEnrollStatus_WithPolicyRequirementsDisabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue( - User user, - Organization organization, - OrganizationUser organizationUser -) + SutProvider sutProvider, + User user, + Organization organization, + OrganizationUser organizationUser) { + var policy = new Policy + { + Type = PolicyType.ResetPassword, + Enabled = true, + Data = "{\"AutoEnrollEnabled\": true}", + OrganizationId = organization.Id + }; - var policy = new Policy() { Type = PolicyType.ResetPassword, Enabled = true, Data = "{\"AutoEnrollEnabled\": true}", OrganizationId = organization.Id }; + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); - _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); - _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); - _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); + var result = await sutProvider.Sut.GetAutoEnrollStatus(organization.Id.ToString()); - var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString()); - - await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any()); - await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString()); - await _policyRequirementQuery.Received(0).GetAsync(user.Id); - await _policyRepository.Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + await sutProvider.GetDependency().Received(1).GetUserByPrincipalAsync(Arg.Any()); + await sutProvider.GetDependency().Received(1).GetByIdentifierAsync(organization.Id.ToString()); + await sutProvider.GetDependency().Received(0).GetAsync(user.Id); + await sutProvider.GetDependency().Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); Assert.True(result.ResetPasswordEnabled); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task PutCollectionManagement_ValidRequest_Success( + SutProvider sutProvider, Organization organization, OrganizationCollectionManagementUpdateRequestModel model) { // Arrange - _currentContext.OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - var plan = StaticStore.GetPlan(PlanType.EnterpriseAnnually); - _pricingClient.GetPlan(Arg.Any()).Returns(plan); + var plan = MockPlans.Get(PlanType.EnterpriseAnnually); + sutProvider.GetDependency().GetPlan(Arg.Any()).Returns(plan); - _organizationService + sutProvider.GetDependency() .UpdateCollectionManagementSettingsAsync( organization.Id, Arg.Is(s => @@ -319,10 +249,10 @@ public class OrganizationsControllerTests : IDisposable .Returns(organization); // Act - await _sut.PutCollectionManagement(organization.Id, model); + await sutProvider.Sut.PutCollectionManagement(organization.Id, model); // Assert - await _organizationService + await sutProvider.GetDependency() .Received(1) .UpdateCollectionManagementSettingsAsync( organization.Id, diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs deleted file mode 100644 index 8a75db9da8..0000000000 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs +++ /dev/null @@ -1,248 +0,0 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations; - -public class OrganizationIntegrationConfigurationRequestModelTests -{ - [Fact] - public void IsValidForType_CloudBillingSyncIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.CloudBillingSync)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Theory] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Fact] - public void IsValidForType_NullConfiguration_ReturnsTrue() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = null, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.True(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.True(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyTemplate_ReturnsFalse(string? template) - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = template - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - [Fact] - public void IsValidForType_InvalidJsonConfiguration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{not valid json}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); - } - - - [Fact] - public void IsValidForType_InvalidJsonFilters_ReturnsFalse() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = "{Not valid json", - Template = "template" - }; - - Assert.False(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ScimIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Scim)); - } - - [Fact] - public void IsValidForType_ValidSlackConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new SlackIntegrationConfiguration(ChannelId: "C12345")); - - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidSlackConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidNoAuthWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( - Uri: new Uri("https://example.com"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_UnknownIntegrationType_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - var unknownType = (IntegrationType)999; - - Assert.False(condition: model.IsValidForType(unknownType)); - } -} diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs index c2893c9fce..30b0ccc272 100644 --- a/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs @@ -48,6 +48,7 @@ public class ProfileOrganizationResponseModelTests UsersGetPremium = organization.UsersGetPremium, UseCustomPermissions = organization.UseCustomPermissions, UseRiskInsights = organization.UseRiskInsights, + UsePhishingBlocker = organization.UsePhishingBlocker, UseOrganizationDomains = organization.UseOrganizationDomains, UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs index a131f90724..1757f9d983 100644 --- a/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs @@ -45,6 +45,7 @@ public class ProfileProviderOrganizationResponseModelTests UsersGetPremium = organization.UsersGetPremium, UseCustomPermissions = organization.UseCustomPermissions, UseRiskInsights = organization.UseRiskInsights, + UsePhishingBlocker = organization.UsePhishingBlocker, UseOrganizationDomains = organization.UseOrganizationDomains, UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, diff --git a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index c2360f5f9a..bd10eab617 100644 --- a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -1,14 +1,11 @@ using Bit.Api.AdminConsole.Public.Controllers; using Bit.Api.AdminConsole.Public.Models.Request; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -22,7 +19,7 @@ public class PoliciesControllerTests { [Theory] [BitAutoData] - public async Task Put_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand( + public async Task Put_UsesVNextSavePolicyCommand( Guid organizationId, PolicyType policyType, PolicyUpdateRequestModel model, @@ -33,9 +30,6 @@ public class PoliciesControllerTests policy.Data = null; sutProvider.GetDependency() .OrganizationId.Returns(organizationId); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(true); sutProvider.GetDependency() .SaveAsync(Arg.Any()) .Returns(policy); @@ -52,36 +46,4 @@ public class PoliciesControllerTests m.PolicyUpdate.Enabled == model.Enabled.GetValueOrDefault() && m.PerformedBy is SystemUser)); } - - [Theory] - [BitAutoData] - public async Task Put_WhenPolicyValidatorsRefactorDisabled_UsesLegacySavePolicyCommand( - Guid organizationId, - PolicyType policyType, - PolicyUpdateRequestModel model, - Policy policy, - SutProvider sutProvider) - { - // Arrange - policy.Data = null; - sutProvider.GetDependency() - .OrganizationId.Returns(organizationId); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(false); - sutProvider.GetDependency() - .SaveAsync(Arg.Any()) - .Returns(policy); - - // Act - await sutProvider.Sut.Put(policyType, model); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .SaveAsync(Arg.Is(p => - p.OrganizationId == organizationId && - p.Type == policyType && - p.Enabled == model.Enabled)); - } } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index f1aa11d068..300a4d823d 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -11,6 +11,7 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; @@ -38,6 +39,7 @@ public class AccountsControllerTests : IDisposable private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; + private readonly IUserRepository _userRepository; public AccountsControllerTests() { @@ -53,6 +55,7 @@ public class AccountsControllerTests : IDisposable _userAccountKeysQuery = Substitute.For(); _twoFactorEmailService = Substitute.For(); _changeKdfCommand = Substitute.For(); + _userRepository = Substitute.For(); _sut = new AccountsController( _organizationService, @@ -66,7 +69,8 @@ public class AccountsControllerTests : IDisposable _featureService, _userAccountKeysQuery, _twoFactorEmailService, - _changeKdfCommand + _changeKdfCommand, + _userRepository ); } @@ -688,6 +692,37 @@ public class AccountsControllerTests : IDisposable await _sut.PostKdf(model); } + [Theory] + [BitAutoData] + public async Task PostKeys_NoUser_Errors(KeysRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(null)); + + await Assert.ThrowsAsync(() => _sut.PostKeys(model)); + } + + [Theory] + [BitAutoData("existing", "existing")] + [BitAutoData((string)null, "existing")] + [BitAutoData("", "existing")] + [BitAutoData(" ", "existing")] + [BitAutoData("existing", null)] + [BitAutoData("existing", "")] + [BitAutoData("existing", " ")] + public async Task PostKeys_UserAlreadyHasKeys_Errors(string? existingPrivateKey, string? existingPublicKey, + KeysRequestModel model) + { + var user = GenerateExampleUser(); + user.PrivateKey = existingPrivateKey; + user.PublicKey = existingPublicKey; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + + var exception = await Assert.ThrowsAsync(() => _sut.PostKeys(model)); + + Assert.NotNull(exception.Message); + Assert.Contains("User has existing keypair", exception.Message); + } + // Below are helper functions that currently belong to this // test class, but ultimately may need to be split out into // something greater in order to share common test steps with @@ -738,5 +773,77 @@ public class AccountsControllerTests : IDisposable _userService.GetUserByIdAsync(Arg.Any()) .Returns(Task.FromResult((User)null)); } + + [Theory, BitAutoData] + public async Task PostKeys_WithAccountKeys_CallsSetV2AccountCryptographicState( + User user, + KeysRequestModel model) + { + // Arrange + user.PublicKey = null; + user.PrivateKey = null; + model.AccountKeys = new AccountKeysRequestModel + { + UserKeyEncryptedAccountPrivateKey = "wrapped-private-key", + AccountPublicKey = "public-key", + PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairRequestModel + { + PublicKey = "public-key", + WrappedPrivateKey = "wrapped-private-key", + SignedPublicKey = "signed-public-key" + }, + SignatureKeyPair = new SignatureKeyPairRequestModel + { + VerifyingKey = "verifying-key", + SignatureAlgorithm = "ed25519", + WrappedSigningKey = "wrapped-signing-key" + }, + SecurityState = new SecurityStateModel + { + SecurityState = "security-state", + SecurityVersion = 2 + } + }; + + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + // Act + var result = await _sut.PostKeys(model); + + // Assert + await _userRepository.Received(1).SetV2AccountCryptographicStateAsync( + user.Id, + Arg.Any()); + await _userService.DidNotReceiveWithAnyArgs().SaveUserAsync(Arg.Any()); + Assert.NotNull(result); + Assert.Equal("keys", result.Object); + } + + [Theory, BitAutoData] + public async Task PostKeys_WithoutAccountKeys_CallsSaveUser( + User user, + KeysRequestModel model) + { + // Arrange + user.PublicKey = null; + user.PrivateKey = null; + model.AccountKeys = null; + model.PublicKey = "public-key"; + model.EncryptedPrivateKey = "encrypted-private-key"; + + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + // Act + var result = await _sut.PostKeys(model); + + // Assert + await _userService.Received(1).SaveUserAsync(Arg.Is(u => + u.PublicKey == model.PublicKey && + u.PrivateKey == model.EncryptedPrivateKey)); + await _userRepository.DidNotReceiveWithAnyArgs() + .SetV2AccountCryptographicStateAsync(Arg.Any(), Arg.Any()); + Assert.NotNull(result); + Assert.Equal("keys", result.Object); + } } diff --git a/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs index d84fddd282..16b9b26436 100644 --- a/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs @@ -4,6 +4,7 @@ using Bit.Core; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.KeyManagement.Queries.Interfaces; @@ -27,9 +28,10 @@ public class AccountsControllerTests : IDisposable private readonly IUserService _userService; private readonly IFeatureService _featureService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IUserAccountKeysQuery _userAccountKeysQuery; + private readonly ILicensingService _licensingService; private readonly GlobalSettings _globalSettings; private readonly AccountsController _sut; @@ -37,16 +39,18 @@ public class AccountsControllerTests : IDisposable { _userService = Substitute.For(); _featureService = Substitute.For(); - _paymentService = Substitute.For(); + _paymentService = Substitute.For(); _twoFactorIsEnabledQuery = Substitute.For(); _userAccountKeysQuery = Substitute.For(); + _licensingService = Substitute.For(); _globalSettings = new GlobalSettings { SelfHosted = false }; _sut = new AccountsController( _userService, _twoFactorIsEnabledQuery, _userAccountKeysQuery, - _featureService + _featureService, + _licensingService ); } diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs index d79bfde893..ee0bdc61e4 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs @@ -3,9 +3,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http.HttpResults; @@ -103,7 +103,7 @@ public class OrganizationBillingControllerTests // Manually create a BillingHistoryInfo object to avoid requiring AutoFixture to create HttpResponseHeaders var billingInfo = new BillingHistoryInfo(); - sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); + sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); // Act var result = await sutProvider.Sut.GetHistoryAsync(organizationId); diff --git a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs index 2ad7686c30..87334dc085 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -24,11 +24,11 @@ namespace Bit.Api.Test.Billing.Controllers; public class OrganizationSponsorshipsControllerTests { public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonConfirmedOrganizationUsersStatuses => Enum.GetValues() diff --git a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs index a776bbea22..9a3f57c3dc 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs @@ -37,7 +37,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationService _organizationService; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IUserService _userService; private readonly IGetCloudOrganizationLicenseQuery _getCloudOrganizationLicenseQuery; @@ -59,7 +59,7 @@ public class OrganizationsControllerTests : IDisposable _organizationRepository = Substitute.For(); _organizationService = Substitute.For(); _organizationUserRepository = Substitute.For(); - _paymentService = Substitute.For(); + _paymentService = Substitute.For(); Substitute.For(); _ssoConfigRepository = Substitute.For(); Substitute.For(); diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index 75bd13eae8..652e82c801 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -1,5 +1,4 @@ using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -12,12 +11,10 @@ using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Models.Api; using Bit.Core.Models.BitStripe; -using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http; @@ -124,7 +121,7 @@ public class ProviderBillingControllerTests } }; - sutProvider.GetDependency().InvoiceListAsync(Arg.Is( + sutProvider.GetDependency().ListInvoicesAsync(Arg.Is( options => options.Customer == provider.GatewayCustomerId)).Returns(invoices); @@ -304,7 +301,7 @@ public class ProviderBillingControllerTests Status = "unpaid" }; - stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer.tax_ids") && options.Expand.Contains("discounts") && @@ -321,7 +318,7 @@ public class ProviderBillingControllerTests Attempted = true }; - stripeAdapter.InvoiceSearchAsync(Arg.Is( + stripeAdapter.SearchInvoiceAsync(Arg.Is( options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) .Returns([overdueInvoice]); @@ -351,10 +348,10 @@ public class ProviderBillingControllerTests foreach (var providerPlan in providerPlans) { - var plan = StaticStore.GetPlan(providerPlan.PlanType); + var plan = MockPlans.Get(providerPlan.PlanType); sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); - sutProvider.GetDependency().PriceGetAsync(priceId) + sutProvider.GetDependency().GetPriceAsync(priceId) .Returns(new Price { UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 @@ -372,7 +369,7 @@ public class ProviderBillingControllerTests Assert.Equal(subscription.Customer!.Discount!.Coupon!.PercentOff, response.DiscountPercentage); Assert.Equal(subscription.CollectionMethod, response.CollectionMethod); - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsPlan = MockPlans.Get(PlanType.TeamsMonthly); var providerTeamsPlan = response.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); Assert.NotNull(providerTeamsPlan); Assert.Equal(50, providerTeamsPlan.SeatMinimum); @@ -381,7 +378,7 @@ public class ProviderBillingControllerTests Assert.Equal(60 * teamsPlan.PasswordManager.ProviderPortalSeatPrice, providerTeamsPlan.Cost); Assert.Equal("Monthly", providerTeamsPlan.Cadence); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterprisePlan = MockPlans.Get(PlanType.EnterpriseMonthly); var providerEnterprisePlan = response.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); Assert.NotNull(providerEnterprisePlan); Assert.Equal(100, providerEnterprisePlan.SeatMinimum); @@ -462,13 +459,13 @@ public class ProviderBillingControllerTests Status = "active" }; - stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer.tax_ids") && options.Expand.Contains("discounts") && options.Expand.Contains("test_clock"))).Returns(subscription); - stripeAdapter.InvoiceSearchAsync(Arg.Is( + stripeAdapter.SearchInvoiceAsync(Arg.Is( options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) .Returns([]); @@ -498,10 +495,10 @@ public class ProviderBillingControllerTests foreach (var providerPlan in providerPlans) { - var plan = StaticStore.GetPlan(providerPlan.PlanType); + var plan = MockPlans.Get(providerPlan.PlanType); sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); - sutProvider.GetDependency().PriceGetAsync(priceId) + sutProvider.GetDependency().GetPriceAsync(priceId) .Returns(new Price { UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 @@ -521,49 +518,4 @@ public class ProviderBillingControllerTests } #endregion - - #region UpdateTaxInformationAsync - - [Theory, BitAutoData] - public async Task UpdateTaxInformation_NoCountry_BadRequest( - Provider provider, - TaxInformationRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableProviderAdminInputs(provider, sutProvider); - - requestBody.Country = null; - - var result = await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); - - Assert.IsType>(result); - - var response = (BadRequest)result; - - Assert.Equal("Country and postal code are required to update your tax information.", response.Value.Message); - } - - [Theory, BitAutoData] - public async Task UpdateTaxInformation_Ok( - Provider provider, - TaxInformationRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableProviderAdminInputs(provider, sutProvider); - - await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); - - await sutProvider.GetDependency().Received(1).UpdateTaxInformation( - provider, Arg.Is( - options => - options.Country == requestBody.Country && - options.PostalCode == requestBody.PostalCode && - options.TaxId == requestBody.TaxId && - options.Line1 == requestBody.Line1 && - options.Line2 == requestBody.Line2 && - options.City == requestBody.City && - options.State == requestBody.State)); - } - - #endregion } diff --git a/test/Api.Test/Controllers/PoliciesControllerTests.cs b/test/Api.Test/Controllers/PoliciesControllerTests.cs index 89d6ddefdc..efb9f7aaa9 100644 --- a/test/Api.Test/Controllers/PoliciesControllerTests.cs +++ b/test/Api.Test/Controllers/PoliciesControllerTests.cs @@ -3,7 +3,6 @@ using System.Text.Json; using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request; using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; @@ -291,7 +290,7 @@ public class PoliciesControllerTests string token, string email, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -302,14 +301,15 @@ public class PoliciesControllerTests var decryptedToken = Substitute.For(); decryptedToken.Valid.Returns(false); - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); // Act & Assert await Assert.ThrowsAsync(() => @@ -325,7 +325,7 @@ public class PoliciesControllerTests string token, string email, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -338,14 +338,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); sutProvider.GetDependency() .GetByIdAsync(organizationUserId) @@ -366,7 +367,7 @@ public class PoliciesControllerTests string email, OrganizationUser orgUser, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -379,14 +380,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); orgUser.OrganizationId = Guid.Empty; @@ -409,7 +411,7 @@ public class PoliciesControllerTests string email, OrganizationUser orgUser, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -422,14 +424,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); orgUser.OrganizationId = orgId; sutProvider.GetDependency() @@ -463,7 +466,7 @@ public class PoliciesControllerTests [Theory] [BitAutoData] - public async Task PutVNext_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand( + public async Task PutVNext_UsesVNextSavePolicyCommand( SutProvider sutProvider, Guid orgId, SavePolicyRequest model, Policy policy, Guid userId) { @@ -478,10 +481,6 @@ public class PoliciesControllerTests .OrganizationOwner(orgId) .Returns(true); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(true); - sutProvider.GetDependency() .SaveAsync(Arg.Any()) .Returns(policy); @@ -492,12 +491,11 @@ public class PoliciesControllerTests // Assert await sutProvider.GetDependency() .Received(1) - .SaveAsync(Arg.Is( - m => m.PolicyUpdate.OrganizationId == orgId && - m.PolicyUpdate.Type == policy.Type && - m.PolicyUpdate.Enabled == model.Policy.Enabled && - m.PerformedBy.UserId == userId && - m.PerformedBy.IsOrganizationOwnerOrProvider == true)); + .SaveAsync(Arg.Is(m => m.PolicyUpdate.OrganizationId == orgId && + m.PolicyUpdate.Type == policy.Type && + m.PolicyUpdate.Enabled == model.Policy.Enabled && + m.PerformedBy.UserId == userId && + m.PerformedBy.IsOrganizationOwnerOrProvider == true)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() @@ -507,51 +505,4 @@ public class PoliciesControllerTests Assert.Equal(policy.Id, result.Id); Assert.Equal(policy.Type, result.Type); } - - [Theory] - [BitAutoData] - public async Task PutVNext_WhenPolicyValidatorsRefactorDisabled_UsesSavePolicyCommand( - SutProvider sutProvider, Guid orgId, - SavePolicyRequest model, Policy policy, Guid userId) - { - // Arrange - policy.Data = null; - - sutProvider.GetDependency() - .UserId - .Returns(userId); - - sutProvider.GetDependency() - .OrganizationOwner(orgId) - .Returns(true); - - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(false); - - sutProvider.GetDependency() - .VNextSaveAsync(Arg.Any()) - .Returns(policy); - - // Act - var result = await sutProvider.Sut.PutVNext(orgId, policy.Type, model); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .VNextSaveAsync(Arg.Is( - m => m.PolicyUpdate.OrganizationId == orgId && - m.PolicyUpdate.Type == policy.Type && - m.PolicyUpdate.Enabled == model.Policy.Enabled && - m.PerformedBy.UserId == userId && - m.PerformedBy.IsOrganizationOwnerOrProvider == true)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SaveAsync(default); - - Assert.NotNull(result); - Assert.Equal(policy.Id, result.Id); - Assert.Equal(policy.Type, result.Type); - } } diff --git a/test/Api.Test/Dirt/HibpControllerTests.cs b/test/Api.Test/Dirt/HibpControllerTests.cs new file mode 100644 index 0000000000..9be8d56eae --- /dev/null +++ b/test/Api.Test/Dirt/HibpControllerTests.cs @@ -0,0 +1,292 @@ +using System.Net; +using System.Reflection; +using Bit.Api.Dirt.Controllers; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Api.Test.Dirt; + +[ControllerCustomize(typeof(HibpController))] +[SutProviderCustomize] +public class HibpControllerTests : IDisposable +{ + private readonly HttpClient _originalHttpClient; + private readonly FieldInfo _httpClientField; + + public HibpControllerTests() + { + // Store original HttpClient for restoration + _httpClientField = typeof(HibpController).GetField("_httpClient", BindingFlags.Static | BindingFlags.NonPublic); + _originalHttpClient = (HttpClient)_httpClientField?.GetValue(null); + } + + public void Dispose() + { + // Restore original HttpClient after tests + _httpClientField?.SetValue(null, _originalHttpClient); + } + + [Theory, BitAutoData] + public async Task Get_WithMissingApiKey_ThrowsBadRequestException( + SutProvider sutProvider, + string username) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = null; + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Equal("HaveIBeenPwned API key not set.", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_WithValidApiKeyAndNoBreaches_Returns200WithEmptyArray( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + var user = new User { Id = userId }; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + // Mock HttpClient to return 404 (no breaches found) + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.NotFound, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + var contentResult = Assert.IsType(result); + Assert.Equal("[]", contentResult.Content); + Assert.Equal("application/json", contentResult.ContentType); + } + + [Theory, BitAutoData] + public async Task Get_WithValidApiKeyAndBreachesFound_Returns200WithBreachData( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var breachData = "[{\"Name\":\"Adobe\",\"Title\":\"Adobe\",\"Domain\":\"adobe.com\"}]"; + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.OK, breachData); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + var contentResult = Assert.IsType(result); + Assert.Equal(breachData, contentResult.Content); + Assert.Equal("application/json", contentResult.ContentType); + } + + [Theory, BitAutoData] + public async Task Get_WithRateLimiting_RetriesWithDelay( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + // First response is rate limited, second is success + var requestCount = 0; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + requestCount++; + if (requestCount == 1) + { + var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests); + response.Headers.Add("retry-after", "1"); + return Task.FromResult(response); + } + else + { + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + } + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + Assert.Equal(2, requestCount); // Verify retry happened + var contentResult = Assert.IsType(result); + Assert.Equal("[]", contentResult.Content); + } + + [Theory, BitAutoData] + public async Task Get_WithServerError_ThrowsBadRequestException( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.InternalServerError, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Contains("Request failed. Status code:", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_WithBadRequest_ThrowsBadRequestException( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.BadRequest, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Contains("Request failed. Status code:", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_EncodesUsernameCorrectly( + SutProvider sutProvider, + Guid userId) + { + // Arrange + var usernameWithSpecialChars = "test+user@example.com"; + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + string capturedUrl = null; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + capturedUrl = request.RequestUri.ToString(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + await sutProvider.Sut.Get(usernameWithSpecialChars); + + // Assert + Assert.NotNull(capturedUrl); + // Username should be URL encoded (+ becomes %2B, @ becomes %40) + Assert.Contains("test%2Buser%40example.com", capturedUrl); + } + + [Theory, BitAutoData] + public async Task SendAsync_IncludesRequiredHeaders( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency().SelfHosted = false; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + HttpRequestMessage capturedRequest = null; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + capturedRequest = request; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + await sutProvider.Sut.Get(username); + + // Assert + Assert.NotNull(capturedRequest); + Assert.True(capturedRequest.Headers.Contains("hibp-api-key")); + Assert.True(capturedRequest.Headers.Contains("hibp-client-id")); + Assert.True(capturedRequest.Headers.Contains("User-Agent")); + Assert.Equal("Bitwarden", capturedRequest.Headers.GetValues("User-Agent").First()); + } + + /// + /// Helper to create a mock HttpClient that returns a specific status code and content + /// + private HttpClient CreateMockHttpClient(HttpStatusCode statusCode, string content) + { + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + return Task.FromResult(new HttpResponseMessage(statusCode) + { + Content = new StringContent(content) + }); + }); + + return new HttpClient(mockHandler); + } +} + +/// +/// Mock HttpMessageHandler for testing HttpClient behavior +/// +public class MockHttpMessageHandler : HttpMessageHandler +{ + private readonly Func> _sendAsync; + + public MockHttpMessageHandler(Func> sendAsync) + { + _sendAsync = sendAsync; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return _sendAsync(request, cancellationToken); + } +} + diff --git a/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index 2e41dd79a0..a1f3088f52 100644 --- a/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -14,7 +14,9 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Commands.Interfaces; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Services; @@ -362,4 +364,39 @@ public class AccountsKeyManagementControllerTests await sutProvider.GetDependency().Received(1) .ConvertToKeyConnectorAsync(Arg.Is(expectedUser)); } + + [Theory] + [BitAutoData] + public async Task GetKeyConnectorConfirmationDetailsAsync_NoUser_Throws( + SutProvider sutProvider, string orgSsoIdentifier) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .ReturnsNull(); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetKeyConnectorConfirmationDetailsAsync(orgSsoIdentifier)); + + await sutProvider.GetDependency().ReceivedWithAnyArgs(0) + .Run(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task GetKeyConnectorConfirmationDetailsAsync_Success( + SutProvider sutProvider, User expectedUser, string orgSsoIdentifier) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .Returns(expectedUser); + sutProvider.GetDependency().Run(orgSsoIdentifier, expectedUser.Id) + .Returns( + new KeyConnectorConfirmationDetails { OrganizationName = "test" } + ); + + var result = await sutProvider.Sut.GetKeyConnectorConfirmationDetailsAsync(orgSsoIdentifier); + + Assert.NotNull(result); + Assert.Equal("test", result.OrganizationName); + await sutProvider.GetDependency().Received(1) + .Run(orgSsoIdentifier, expectedUser.Id); + } } diff --git a/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs b/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs index 704371eebd..e1e97efce2 100644 --- a/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs +++ b/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs @@ -1,6 +1,6 @@ #nullable enable -using Bit.Api.KeyManagement.Models.Requests; +using Bit.Core.KeyManagement.Models.Api.Request; using Xunit; namespace Bit.Api.Test.KeyManagement.Models.Request; diff --git a/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs new file mode 100644 index 0000000000..79a339fcba --- /dev/null +++ b/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs @@ -0,0 +1,307 @@ +using Bit.Api.SecretsManager.Controllers; +using Bit.Api.SecretsManager.Models.Request; +using Bit.Core.Auth.Identity; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Entities; +using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.SecretsManager.AutoFixture.SecretsFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.SecretsManager.Controllers; + +[ControllerCustomize(typeof(SecretVersionsController))] +[SutProviderCustomize] +[SecretCustomize] +public class SecretVersionsControllerTests +{ + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_SecretNotFound_Throws( + SutProvider sutProvider, + Guid secretId) + { + sutProvider.GetDependency().GetByIdAsync(secretId).Returns((Secret?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secretId)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_NoAccess_Throws( + SutProvider sutProvider, + Secret secret) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(false); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_NoReadAccess_Throws( + SutProvider sutProvider, + Secret secret, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((false, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_Success( + SutProvider sutProvider, + Secret secret, + List versions, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + foreach (var version in versions) + { + version.SecretId = secret.Id; + } + sutProvider.GetDependency().GetManyBySecretIdAsync(secret.Id).Returns(versions); + + var result = await sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id); + + Assert.Equal(versions.Count, result.Data.Count()); + await sutProvider.GetDependency().Received(1) + .GetManyBySecretIdAsync(Arg.Is(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetById_VersionNotFound_Throws( + SutProvider sutProvider, + Guid versionId) + { + sutProvider.GetDependency().GetByIdAsync(versionId).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetByIdAsync(versionId)); + } + + [Theory] + [BitAutoData] + public async Task GetById_Success( + SutProvider sutProvider, + SecretVersion version, + Secret secret, + Guid userId) + { + version.SecretId = secret.Id; + sutProvider.GetDependency().GetByIdAsync(version.Id).Returns(version); + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + var result = await sutProvider.Sut.GetByIdAsync(version.Id); + + Assert.Equal(version.Id, result.Id); + Assert.Equal(version.SecretId, result.SecretId); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_NoWriteAccess_Throws( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId) + { + version.SecretId = secret.Id; + request.VersionId = version.Id; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_VersionNotFound_Throws( + SutProvider sutProvider, + Secret secret, + RestoreSecretVersionRequestModel request, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_VersionBelongsToDifferentSecret_Throws( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId) + { + version.SecretId = Guid.NewGuid(); // Different secret + request.VersionId = version.Id; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns(version); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_Success( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId, + OrganizationUser organizationUser) + { + version.SecretId = secret.Id; + request.VersionId = version.Id; + var versionValue = version.Value; + organizationUser.OrganizationId = secret.OrganizationId; + organizationUser.UserId = userId; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns(version); + sutProvider.GetDependency() + .GetByOrganizationAsync(secret.OrganizationId, userId).Returns(organizationUser); + sutProvider.GetDependency().UpdateAsync(Arg.Any()).Returns(x => x.Arg()); + + var result = await sutProvider.Sut.RestoreVersionAsync(secret.Id, request); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(Arg.Is(s => s.Value == versionValue)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_EmptyIds_Throws( + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(new List())); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_VersionNotFound_Throws( + SutProvider sutProvider, + List ids) + { + sutProvider.GetDependency().GetByIdAsync(ids[0]).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(ids)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_NoWriteAccess_Throws( + SutProvider sutProvider, + List versions, + Secret secret, + Guid userId) + { + var ids = versions.Select(v => v.Id).ToList(); + foreach (var version in versions) + { + version.SecretId = secret.Id; + sutProvider.GetDependency().GetByIdAsync(version.Id).Returns(version); + } + + sutProvider.GetDependency().GetManyByIds(Arg.Any>()) + .Returns(new List { secret }); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(ids)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_Success( + SutProvider sutProvider, + List versions, + Secret secret, + Guid userId) + { + var ids = versions.Select(v => v.Id).ToList(); + foreach (var version in versions) + { + version.SecretId = secret.Id; + } + + sutProvider.GetDependency().GetManyByIdsAsync(ids).Returns(versions); + sutProvider.GetDependency().GetManyByIds(Arg.Any>()) + .Returns(new List { secret }); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + + await sutProvider.Sut.BulkDeleteAsync(ids); + + await sutProvider.GetDependency().Received(1) + .DeleteManyByIdAsync(Arg.Is>(x => x.SequenceEqual(ids))); + } +} diff --git a/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs index 83a4229f39..51f61ad7c1 100644 --- a/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs +++ b/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs @@ -2,6 +2,7 @@ using Bit.Api.SecretsManager.Controllers; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.Test.SecretsManager.Enums; +using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -244,6 +245,7 @@ public class SecretsControllerTests { data = SetupSecretUpdateRequest(data); SetControllerUser(sutProvider, new Guid()); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); sutProvider.GetDependency() .AuthorizeAsync(Arg.Any(), Arg.Any(), Arg.Any>()).ReturnsForAnyArgs(AuthorizationResult.Success()); @@ -602,6 +604,7 @@ public class SecretsControllerTests { data = SetupSecretUpdateRequest(data, true); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); sutProvider.GetDependency() .AuthorizeAsync(Arg.Any(), Arg.Any(), Arg.Any>()).Returns(AuthorizationResult.Success()); diff --git a/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs index 78224a8bd8..5d3b7f2fa5 100644 --- a/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs +++ b/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs @@ -16,7 +16,7 @@ using Bit.Core.SecretsManager.Models.Data; using Bit.Core.SecretsManager.Queries.ServiceAccounts.Interfaces; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -121,7 +121,7 @@ public class ServiceAccountsControllerTests { ArrangeCreateServiceAccountAutoScalingTest(newSlotsRequired, sutProvider, data, organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); await sutProvider.Sut.CreateAsync(organization.Id, data); diff --git a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs index 9f54cdbea5..238053464c 100644 --- a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs @@ -79,7 +79,7 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(id, userId).ReturnsForAnyArgs(cipherDetails); sutProvider.GetDependency().GetManyByUserIdCipherIdAsync(userId, id).Returns((ICollection)new List()); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility() } }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility { Id = cipherDetails.OrganizationId.Value } } }); var cipherService = sutProvider.GetDependency(); await sutProvider.Sut.PutCollections_vNext(id, model); @@ -95,7 +95,7 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(id, userId).ReturnsForAnyArgs(cipherDetails); sutProvider.GetDependency().GetManyByUserIdCipherIdAsync(userId, id).Returns((ICollection)new List()); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility() } }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility { Id = cipherDetails.OrganizationId.Value } } }); var result = await sutProvider.Sut.PutCollections_vNext(id, model); @@ -1790,118 +1790,6 @@ public class CiphersControllerTests ); } - [Theory, BitAutoData] - public async Task PutShareMany_ArchivedCipher_ThrowsBadRequestException( - Guid organizationId, - Guid userId, - CipherWithIdRequestModel request, - SutProvider sutProvider) - { - request.EncryptedFor = userId; - request.OrganizationId = organizationId.ToString(); - request.ArchivedDate = DateTime.UtcNow; - var model = new CipherBulkShareRequestModel - { - Ciphers = [request], - CollectionIds = [Guid.NewGuid().ToString()] - }; - - sutProvider.GetDependency() - .OrganizationUser(organizationId) - .Returns(Task.FromResult(true)); - sutProvider.GetDependency() - .GetProperUserId(default) - .ReturnsForAnyArgs(userId); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PutShareMany(model) - ); - - Assert.Equal("Cannot move archived items to an organization.", exception.Message); - } - - [Theory, BitAutoData] - public async Task PutShareMany_ExistingCipherArchived_ThrowsBadRequestException( - Guid organizationId, - Guid userId, - CipherWithIdRequestModel request, - SutProvider sutProvider) - { - // Request model does not have ArchivedDate (only the existing cipher does) - request.EncryptedFor = userId; - request.OrganizationId = organizationId.ToString(); - request.ArchivedDate = null; - - var model = new CipherBulkShareRequestModel - { - Ciphers = [request], - CollectionIds = [Guid.NewGuid().ToString()] - }; - - // The existing cipher from the repository IS archived - var existingCipher = new CipherDetails - { - Id = request.Id!.Value, - UserId = userId, - Type = CipherType.Login, - Data = JsonSerializer.Serialize(new CipherLoginData()), - ArchivedDate = DateTime.UtcNow - }; - - sutProvider.GetDependency() - .OrganizationUser(organizationId) - .Returns(Task.FromResult(true)); - sutProvider.GetDependency() - .GetProperUserId(default) - .ReturnsForAnyArgs(userId); - sutProvider.GetDependency() - .GetManyByUserIdAsync(userId, withOrganizations: false) - .Returns(Task.FromResult((ICollection)[existingCipher])); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PutShareMany(model) - ); - - Assert.Equal("Cannot move archived items to an organization.", exception.Message); - } - - [Theory, BitAutoData] - public async Task PutShare_ArchivedCipher_ThrowsBadRequestException( - Guid cipherId, - Guid organizationId, - User user, - CipherShareRequestModel model, - SutProvider sutProvider) - { - model.Cipher.OrganizationId = organizationId.ToString(); - model.Cipher.EncryptedFor = user.Id; - - var cipher = new Cipher - { - Id = cipherId, - UserId = user.Id, - ArchivedDate = DateTime.UtcNow.AddDays(-1), - Type = CipherType.Login, - Data = JsonSerializer.Serialize(new CipherLoginData()) - }; - - sutProvider.GetDependency() - .GetUserByPrincipalAsync(Arg.Any()) - .Returns(user); - sutProvider.GetDependency() - .GetByIdAsync(cipherId) - .Returns(cipher); - sutProvider.GetDependency() - .OrganizationUser(organizationId) - .Returns(Task.FromResult(true)); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PutShare(cipherId, model) - ); - - Assert.Equal("Cannot move an archived item to an organization.", exception.Message); - } - [Theory, BitAutoData] public async Task PostPurge_WhenUserNotFound_ThrowsUnauthorizedAccessException( SecretVerificationRequestModel model, @@ -2021,4 +1909,237 @@ public class CiphersControllerTests await Assert.ThrowsAsync(() => sutProvider.Sut.PostPurge(model, organizationId)); } + + [Theory, BitAutoData] + public async Task PutShare_WithNullFolderAndFalseFavorite_UpdatesFieldsCorrectly( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid folderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, folderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }) + }; + + // Clears folder and favorite when sharing + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = null, + Favorite = false, + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = null, + Favorite = false + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Null(result.FolderId); + Assert.False(result.Favorite); + } + + [Theory, BitAutoData] + public async Task PutShare_WithFolderAndFavoriteSet_AddsUserSpecificFields( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid folderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = null, + Favorites = null + }; + + // Sets folder and favorite when sharing + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = folderId.ToString(), + Favorite = true, + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, folderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }), + FolderId = folderId, + Favorite = true + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Equal(folderId, result.FolderId); + Assert.True(result.Favorite); + } + + [Theory, BitAutoData] + public async Task PutShare_UpdateExistingFolderAndFavorite_UpdatesUserSpecificFields( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid oldFolderId, + Guid newFolderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + // Existing cipher with old folder and not favorited + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, oldFolderId.ToString().ToUpperInvariant() } }), + Favorites = null + }; + + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = newFolderId.ToString(), // Update to new folder + Favorite = true, // Add favorite + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, newFolderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }), + FolderId = newFolderId, + Favorite = true + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Equal(newFolderId, result.FolderId); + Assert.True(result.Favorite); + } } diff --git a/test/Api.Test/Vault/Controllers/SyncControllerTests.cs b/test/Api.Test/Vault/Controllers/SyncControllerTests.cs index a46eba283d..e6d34592c7 100644 --- a/test/Api.Test/Vault/Controllers/SyncControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/SyncControllerTests.cs @@ -18,9 +18,9 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tools.Entities; using Bit.Core.Tools.Repositories; -using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; @@ -335,7 +335,7 @@ public class SyncControllerTests if (matchedProviderUserOrgDetails != null) { - var providerOrgProductType = StaticStore.GetPlan(matchedProviderUserOrgDetails.PlanType).ProductTier; + var providerOrgProductType = MockPlans.Get(matchedProviderUserOrgDetails.PlanType).ProductTier; Assert.Equal(providerOrgProductType, profProviderOrg.ProductTierType); } } diff --git a/test/Billing.Test/Billing.Test.csproj b/test/Billing.Test/Billing.Test.csproj index 4d7f887c90..87a1c28ca1 100644 --- a/test/Billing.Test/Billing.Test.csproj +++ b/test/Billing.Test/Billing.Test.csproj @@ -5,8 +5,8 @@ - + @@ -24,6 +24,7 @@ + diff --git a/test/Billing.Test/Controllers/BitPayControllerTests.cs b/test/Billing.Test/Controllers/BitPayControllerTests.cs index d2d1c5b571..0118009cb7 100644 --- a/test/Billing.Test/Controllers/BitPayControllerTests.cs +++ b/test/Billing.Test/Controllers/BitPayControllerTests.cs @@ -31,7 +31,7 @@ public class BitPayControllerTests private readonly IUserRepository _userRepository = Substitute.For(); private readonly IProviderRepository _providerRepository = Substitute.For(); private readonly IMailService _mailService = Substitute.For(); - private readonly IPaymentService _paymentService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); private readonly IPremiumUserBillingService _premiumUserBillingService = Substitute.For(); diff --git a/test/Billing.Test/Controllers/PayPalControllerTests.cs b/test/Billing.Test/Controllers/PayPalControllerTests.cs index 7ec17bd85a..da995b6188 100644 --- a/test/Billing.Test/Controllers/PayPalControllerTests.cs +++ b/test/Billing.Test/Controllers/PayPalControllerTests.cs @@ -8,13 +8,13 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; -using Divergic.Logging.Xunit; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Neovolve.Logging.Xunit; using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; @@ -23,14 +23,12 @@ using Transaction = Bit.Core.Entities.Transaction; namespace Bit.Billing.Test.Controllers; -public class PayPalControllerTests +public class PayPalControllerTests(ITestOutputHelper testOutputHelper) { - private readonly ITestOutputHelper _testOutputHelper; - private readonly IOptions _billingSettings = Substitute.For>(); private readonly IMailService _mailService = Substitute.For(); private readonly IOrganizationRepository _organizationRepository = Substitute.For(); - private readonly IPaymentService _paymentService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); private readonly ITransactionRepository _transactionRepository = Substitute.For(); private readonly IUserRepository _userRepository = Substitute.For(); private readonly IProviderRepository _providerRepository = Substitute.For(); @@ -38,15 +36,10 @@ public class PayPalControllerTests private const string _defaultWebhookKey = "webhook-key"; - public PayPalControllerTests(ITestOutputHelper testOutputHelper) - { - _testOutputHelper = testOutputHelper; - } - [Fact] public async Task PostIpn_NullKey_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); var controller = ConfigureControllerContextWith(logger, null, null); @@ -60,7 +53,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_IncorrectKey_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -79,7 +72,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_EmptyIPNBody_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -98,7 +91,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_IPNHasNoEntityId_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -119,15 +112,13 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_OtherTransactionType_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { PayPal = { WebhookKey = _defaultWebhookKey } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.UnsupportedTransactionType); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -142,7 +133,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_MismatchedReceiverID_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -153,8 +144,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -169,7 +158,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_RefundMissingParent_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -180,8 +169,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.RefundMissingParentTransaction); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -196,7 +183,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_eCheckPayment_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -207,8 +194,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.ECheckPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -223,7 +208,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_NonUSD_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -234,8 +219,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.NonUSDPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -250,7 +233,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_ExistingTransaction_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -261,8 +244,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); _transactionRepository.GetByGatewayIdAsync( @@ -281,7 +262,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -292,8 +273,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); _transactionRepository.GetByGatewayIdAsync( @@ -314,7 +293,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_CreditsOrganizationAccount_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -362,7 +341,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_CreditsUserAccount_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -406,7 +385,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_ExistingTransaction_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -417,8 +396,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulRefund); _transactionRepository.GetByGatewayIdAsync( @@ -441,7 +418,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_MissingParentTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -452,8 +429,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulRefund); _transactionRepository.GetByGatewayIdAsync( @@ -480,7 +455,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_ReplacesParent_CreatesTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -531,8 +506,8 @@ public class PayPalControllerTests private PayPalController ConfigureControllerContextWith( ILogger logger, - string webhookKey, - string ipnBody) + string? webhookKey, + string? ipnBody) { var controller = new PayPalController( _billingSettings, @@ -578,16 +553,16 @@ public class PayPalControllerTests Assert.Equal(statusCode, statusCodeActionResult.StatusCode); } - private static void Logged(ICacheLogger logger, LogLevel logLevel, string message) + private static void Logged(ICacheLogger logger, LogLevel logLevel, string message) { Assert.NotNull(logger.Last); Assert.Equal(logLevel, logger.Last!.LogLevel); Assert.Equal(message, logger.Last!.Message); } - private static void LoggedError(ICacheLogger logger, string message) + private static void LoggedError(ICacheLogger logger, string message) => Logged(logger, LogLevel.Error, message); - private static void LoggedWarning(ICacheLogger logger, string message) + private static void LoggedWarning(ICacheLogger logger, string message) => Logged(logger, LogLevel.Warning, message); } diff --git a/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs b/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs new file mode 100644 index 0000000000..b3540246b0 --- /dev/null +++ b/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs @@ -0,0 +1,789 @@ +using Bit.Billing.Jobs; +using Bit.Billing.Services; +using Bit.Core; +using Bit.Core.Billing.Constants; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Quartz; +using Stripe; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class ReconcileAdditionalStorageJobTests +{ + private readonly IStripeFacade _stripeFacade; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; + private readonly ReconcileAdditionalStorageJob _sut; + + public ReconcileAdditionalStorageJobTests() + { + _stripeFacade = Substitute.For(); + _logger = Substitute.For>(); + _featureService = Substitute.For(); + _sut = new ReconcileAdditionalStorageJob(_stripeFacade, _logger, _featureService); + } + + #region Feature Flag Tests + + [Fact] + public async Task Execute_FeatureFlagDisabled_SkipsProcessing() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob) + .Returns(false); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.DidNotReceiveWithAnyArgs().ListSubscriptionsAutoPagingAsync(); + } + + [Fact] + public async Task Execute_FeatureFlagEnabled_ProcessesSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob) + .Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode) + .Returns(false); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Empty()); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.Received(3).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Limit == 100)); + } + + #endregion + + #region Dry Run Mode Tests + + [Fact] + public async Task Execute_DryRunMode_DoesNotUpdateSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(false); // Dry run ON + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_DryRunModeDisabled_UpdatesSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); // Dry run OFF + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => o.Items.Count == 1)); + } + + #endregion + + #region Price ID Processing Tests + + [Fact] + public async Task Execute_ProcessesAllThreePriceIds() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(false); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Empty()); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "storage-gb-monthly")); + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "storage-gb-annually")); + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "personal-storage-gb-annually")); + } + + #endregion + + #region Already Processed Tests + + [Fact] + public async Task Execute_SubscriptionAlreadyProcessed_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") + }; + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: metadata); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_SubscriptionWithInvalidProcessedDate_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = "invalid-date" + }; + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: metadata); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_SubscriptionWithoutMetadata_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: null); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + #endregion + + #region Quantity Reduction Logic Tests + + [Fact] + public async Task Execute_QuantityGreaterThan4_ReducesBy4() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Quantity == 6 && + o.Items[0].Deleted != true)); + } + + [Fact] + public async Task Execute_QuantityEquals4_DeletesItem() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 4); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Deleted == true)); + } + + [Fact] + public async Task Execute_QuantityLessThan4_DeletesItem() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 2); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Deleted == true)); + } + + #endregion + + #region Update Options Tests + + [Fact] + public async Task Execute_UpdateOptions_SetsProrationBehaviorToCreateProrations() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => o.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)); + } + + [Fact] + public async Task Execute_UpdateOptions_SetsReconciledMetadata() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Metadata.ContainsKey(StripeConstants.MetadataKeys.StorageReconciled2025) && + !string.IsNullOrEmpty(o.Metadata[StripeConstants.MetadataKeys.StorageReconciled2025]))); + } + + #endregion + + #region Subscription Filtering Tests + + [Fact] + public async Task Execute_SubscriptionWithNoItems_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = new Subscription + { + Id = "sub_123", + Items = null + }; + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_SubscriptionWithDifferentPriceId_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "different-price-id", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_NullSubscription_SkipsProcessing() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(null!)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + #endregion + + #region Multiple Subscriptions Tests + + [Fact] + public async Task Execute_MultipleSubscriptions_ProcessesAll() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_1" => subscription1, + "sub_2" => subscription2, + "sub_3" => subscription3, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + [Fact] + public async Task Execute_MixedSubscriptionsWithProcessed_OnlyProcessesUnprocessed() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var processedMetadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") + }; + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5, metadata: processedMetadata); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_1" => subscription1, + "sub_3" => subscription3, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + #endregion + + #region Error Handling Tests + + [Fact] + public async Task Execute_UpdateFails_ContinuesProcessingOthers() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + + _stripeFacade.UpdateSubscription("sub_1", Arg.Any()) + .Returns(subscription1); + _stripeFacade.UpdateSubscription("sub_2", Arg.Any()) + .Throws(new Exception("Stripe API error")); + _stripeFacade.UpdateSubscription("sub_3", Arg.Any()) + .Returns(subscription3); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + [Fact] + public async Task Execute_UpdateFails_LogsError() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Throws(new Exception("Stripe API error")); + + // Act + await _sut.Execute(context); + + // Assert + _logger.Received().Log( + LogLevel.Error, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + #endregion + + #region Subscription Status Filtering Tests + + [Fact] + public async Task Execute_ActiveStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_TrialingStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Trialing); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_PastDueStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.PastDue); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_CanceledStatusSubscription_SkipsSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Canceled); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_IncompleteStatusSubscription_SkipsSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Incomplete); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_MixedSubscriptionStatuses_OnlyProcessesValidStatuses() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var activeSubscription = CreateSubscription("sub_active", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active); + var trialingSubscription = CreateSubscription("sub_trialing", "storage-gb-monthly", quantity: 8, status: StripeConstants.SubscriptionStatus.Trialing); + var pastDueSubscription = CreateSubscription("sub_pastdue", "storage-gb-monthly", quantity: 6, status: StripeConstants.SubscriptionStatus.PastDue); + var canceledSubscription = CreateSubscription("sub_canceled", "storage-gb-monthly", quantity: 5, status: StripeConstants.SubscriptionStatus.Canceled); + var incompleteSubscription = CreateSubscription("sub_incomplete", "storage-gb-monthly", quantity: 4, status: StripeConstants.SubscriptionStatus.Incomplete); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(activeSubscription, trialingSubscription, pastDueSubscription, canceledSubscription, incompleteSubscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_active" => activeSubscription, + "sub_trialing" => trialingSubscription, + "sub_pastdue" => pastDueSubscription, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_active", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_trialing", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_pastdue", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_canceled", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_incomplete", Arg.Any()); + } + + #endregion + + #region Cancellation Tests + + [Fact] + public async Task Execute_CancellationRequested_LogsWarningAndExits() + { + // Arrange + var cts = new CancellationTokenSource(); + cts.Cancel(); // Cancel immediately + var context = CreateJobExecutionContext(cts.Token); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1)); + + // Act + await _sut.Execute(context); + + // Assert - Should not process any subscriptions due to immediate cancellation + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null); + _logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + #endregion + + #region Helper Methods + + private static IJobExecutionContext CreateJobExecutionContext(CancellationToken cancellationToken = default) + { + var context = Substitute.For(); + context.CancellationToken.Returns(cancellationToken); + return context; + } + + private static Subscription CreateSubscription( + string id, + string priceId, + long? quantity = null, + Dictionary? metadata = null, + string status = StripeConstants.SubscriptionStatus.Active) + { + var price = new Price { Id = priceId }; + var item = new SubscriptionItem + { + Id = $"si_{id}", + Price = price, + Quantity = quantity ?? 0 + }; + + return new Subscription + { + Id = id, + Status = status, + Metadata = metadata, + Items = new StripeList + { + Data = new List { item } + } + }; + } + + #endregion +} + +internal static class AsyncEnumerable +{ + public static async IAsyncEnumerable Create(params T[] items) + { + foreach (var item in items) + { + yield return item; + } + await Task.CompletedTask; + } + + public static async IAsyncEnumerable Empty() + { + await Task.CompletedTask; + yield break; + } +} diff --git a/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs b/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs new file mode 100644 index 0000000000..03bf24f7ff --- /dev/null +++ b/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs @@ -0,0 +1,388 @@ +using Bit.Billing.Constants; +using Bit.Billing.Jobs; +using Bit.Billing.Services; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Repositories; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Quartz; +using Stripe; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class SubscriptionCancellationJobTests +{ + private readonly IStripeFacade _stripeFacade; + private readonly IOrganizationRepository _organizationRepository; + private readonly SubscriptionCancellationJob _sut; + + public SubscriptionCancellationJobTests() + { + _stripeFacade = Substitute.For(); + _organizationRepository = Substitute.For(); + _sut = new SubscriptionCancellationJob(_stripeFacade, _organizationRepository, Substitute.For>()); + } + + [Fact] + public async Task Execute_OrganizationIsNull_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + _organizationRepository.GetByIdAsync(organizationId).Returns((Organization)null); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().GetSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().CancelSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_OrganizationIsEnabled_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = true + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().GetSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().CancelSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_SubscriptionStatusIsNotUnpaid_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Active, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceive().CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_BillingReasonIsInvalid_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "manual" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceive().CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_ValidConditions_CancelsSubscriptionAndVoidsInvoices() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = + [ + new Invoice { Id = "inv_1" }, + new Invoice { Id = "inv_2" } + ], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1).VoidInvoice("inv_1"); + await _stripeFacade.Received(1).VoidInvoice("inv_2"); + } + + [Fact] + public async Task Execute_WithSubscriptionCreateBillingReason_CancelsSubscription() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_create" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_NoOpenInvoices_CancelsSubscriptionOnly() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().VoidInvoice(Arg.Any()); + } + + [Fact] + public async Task Execute_WithPagination_VoidsAllInvoices() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // First page of invoices + var firstPage = new StripeList + { + Data = + [ + new Invoice { Id = "inv_1" }, + new Invoice { Id = "inv_2" } + ], + HasMore = true + }; + + // Second page of invoices + var secondPage = new StripeList + { + Data = + [ + new Invoice { Id = "inv_3" }, + new Invoice { Id = "inv_4" } + ], + HasMore = false + }; + + _stripeFacade.ListInvoices(Arg.Is(o => o.StartingAfter == null)) + .Returns(firstPage); + _stripeFacade.ListInvoices(Arg.Is(o => o.StartingAfter == "inv_2")) + .Returns(secondPage); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1).VoidInvoice("inv_1"); + await _stripeFacade.Received(1).VoidInvoice("inv_2"); + await _stripeFacade.Received(1).VoidInvoice("inv_3"); + await _stripeFacade.Received(1).VoidInvoice("inv_4"); + await _stripeFacade.Received(2).ListInvoices(Arg.Any()); + } + + [Fact] + public async Task Execute_ListInvoicesCalledWithCorrectOptions() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))); + await _stripeFacade.Received(1).ListInvoices(Arg.Is(o => + o.Status == "open" && + o.Subscription == subscriptionId && + o.Limit == 100)); + } + + private static IJobExecutionContext CreateJobExecutionContext(string subscriptionId, Guid organizationId) + { + var context = Substitute.For(); + var jobDataMap = new JobDataMap + { + { "subscriptionId", subscriptionId }, + { "organizationId", organizationId.ToString() } + }; + context.MergedJobDataMap.Returns(jobDataMap); + return context; + } +} diff --git a/test/Billing.Test/Services/ProviderEventServiceTests.cs b/test/Billing.Test/Services/ProviderEventServiceTests.cs index d5f273fa65..34c69b95c2 100644 --- a/test/Billing.Test/Services/ProviderEventServiceTests.cs +++ b/test/Billing.Test/Services/ProviderEventServiceTests.cs @@ -9,7 +9,7 @@ using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Enums; using Bit.Core.Repositories; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using NSubstitute; using Stripe; using Xunit; @@ -237,7 +237,7 @@ public class ProviderEventServiceTests foreach (var providerPlan in providerPlans) { - _pricingClient.GetPlanOrThrow(providerPlan.PlanType).Returns(StaticStore.GetPlan(providerPlan.PlanType)); + _pricingClient.GetPlanOrThrow(providerPlan.PlanType).Returns(MockPlans.Get(providerPlan.PlanType)); } _providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); @@ -246,8 +246,8 @@ public class ProviderEventServiceTests await _providerEventService.TryRecordInvoiceLineItems(stripeEvent); // Assert - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterprisePlan = MockPlans.Get(PlanType.EnterpriseMonthly); await _providerInvoiceItemRepository.Received(1).CreateAsync(Arg.Is( options => diff --git a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs index e9f0d9d0ed..a7aefe3163 100644 --- a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs +++ b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs @@ -4,8 +4,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -61,7 +61,7 @@ public class SetupIntentSucceededHandlerTests // Assert await _setupIntentCache.DidNotReceiveWithAnyArgs().GetSubscriberIdForSetupIntent(Arg.Any()); - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -86,7 +86,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -116,7 +116,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == organization.GatewayCustomerId)); @@ -151,7 +151,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == provider.GatewayCustomerId)); @@ -183,7 +183,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -216,7 +216,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 16287bc5c9..182f09e163 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -1,18 +1,17 @@ using Bit.Billing.Constants; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; using NSubstitute; @@ -21,6 +20,8 @@ using Quartz; using Stripe; using Xunit; using Event = Stripe.Event; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; +using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; namespace Bit.Billing.Test.Services; @@ -126,79 +127,6 @@ public class SubscriptionUpdatedHandlerTests Arg.Is(t => t.Key.Name == $"cancel-trigger-{subscriptionId}")); } - [Fact] - public async Task - HandleAsync_UnpaidProviderSubscription_WithManualSuspensionViaMetadata_DisablesProviderAndSchedulesCancellation() - { - // Arrange - var providerId = Guid.NewGuid(); - var subscriptionId = "sub_test123"; - - var previousSubscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Active, - Metadata = new Dictionary - { - ["suspend_provider"] = null // This is the key part - metadata exists, but value is null - } - }; - - var currentSubscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, - Items = new StripeList - { - Data = - [ - new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } - ] - }, - Metadata = new Dictionary - { - ["providerId"] = providerId.ToString(), - ["suspend_provider"] = "true" // Now has a value, indicating manual suspension - }, - TestClock = null - }; - - var parsedEvent = new Event - { - Id = "evt_test123", - Type = HandledStripeWebhook.SubscriptionUpdated, - Data = new EventData - { - Object = currentSubscription, - PreviousAttributes = JObject.FromObject(previousSubscription) - } - }; - - var provider = new Provider { Id = providerId, Enabled = true }; - - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); - _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); - _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId).Returns(provider); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); - - // Verify that UpdateSubscription was called with both CancelAt and the new metadata - await _stripeFacade.Received(1).UpdateSubscription( - subscriptionId, - Arg.Is(options => - options.CancelAt.HasValue && - options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && - options.Metadata != null && - options.Metadata.ContainsKey("suspended_provider_via_webhook_at"))); - } - [Fact] public async Task HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() @@ -243,7 +171,6 @@ public class SubscriptionUpdatedHandlerTests var provider = new Provider { Id = providerId, Enabled = true }; - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) .Returns(Tuple.Create(null, null, providerId)); @@ -256,13 +183,12 @@ public class SubscriptionUpdatedHandlerTests Assert.False(provider.Enabled); await _providerService.Received(1).UpdateAsync(provider); - // Verify that UpdateSubscription was called with CancelAt but WITHOUT suspension metadata + // Verify that UpdateSubscription was called with CancelAt await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => options.CancelAt.HasValue && - options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && - (options.Metadata == null || !options.Metadata.ContainsKey("suspended_provider_via_webhook_at")))); + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1))); } [Fact] @@ -306,9 +232,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -353,9 +276,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -401,9 +321,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -416,48 +333,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } - [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WhenFeatureFlagDisabled_DoesNothing() - { - // Arrange - var providerId = Guid.NewGuid(); - var subscriptionId = "sub_123"; - var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - - var subscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, - Items = new StripeList - { - Data = - [ - new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } - ] - }, - Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } - }; - - var parsedEvent = new Event { Data = new EventData() }; - - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); - - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(false); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); - } - [Fact] public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_DoesNothing() { @@ -489,9 +364,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns((Provider)null); @@ -530,6 +402,75 @@ public class SubscriptionUpdatedHandlerTests var parsedEvent = new Event { Data = new EventData() }; + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List() }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + } + + [Fact] + public async Task HandleAsync_IncompleteExpiredUserSubscription_DisablesPremiumAndCancelsSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.IncompleteExpired, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); @@ -695,7 +636,7 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), - Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + Plan = new Stripe.Plan { Id = "2023-enterprise-org-seat-annually" } } ] }, @@ -729,7 +670,7 @@ public class SubscriptionUpdatedHandlerTests { Data = [ - new SubscriptionItem { Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } } + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-enterprise-seat-annually" } } ] } }) @@ -777,8 +718,6 @@ public class SubscriptionUpdatedHandlerTests _stripeFacade .UpdateSubscription(Arg.Any(), Arg.Any()) .Returns(newSubscription); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -800,9 +739,6 @@ public class SubscriptionUpdatedHandlerTests .Received(1) .UpdateSubscription(newSubscription.Id, Arg.Is(options => options.CancelAtPeriodEnd == false)); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -823,8 +759,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -843,9 +777,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -866,8 +797,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -886,9 +815,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -909,8 +835,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -929,9 +853,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -953,8 +874,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -975,9 +894,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -997,8 +913,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .ReturnsNull(); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -1019,9 +933,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceive() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -1040,8 +951,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -1062,9 +971,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceive() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } private static (Guid providerId, Subscription newSubscription, Provider provider, Event parsedEvent) @@ -1098,6 +1004,134 @@ public class SubscriptionUpdatedHandlerTests return (providerId, newSubscription, provider, parsedEvent); } + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithOpenInvoice_CancelsSubscriptionAndDisablesPremium() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var openInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Open + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = openInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List { openInvoice } }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + await _stripeFacade.Received(1) + .VoidInvoice(openInvoice.Id); + } + + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithoutOpenInvoice_DoesNotCancelSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var paidInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Paid + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = paidInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.DidNotReceive() + .DisablePremiumAsync(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .CancelSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .ListInvoices(Arg.Any()); + } + public static IEnumerable GetNonActiveSubscriptions() { return new List diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs index 5ac77eb42a..3b133c7d37 100644 --- a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -1,21 +1,24 @@ -using Bit.Billing.Services; +using System.Globalization; +using Bit.Billing.Services; using Bit.Billing.Services.Implementations; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing.Premium; using Bit.Core.Entities; -using Bit.Core.Models.Mail.UpdatedInvoiceIncoming; +using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Premium; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using NSubstitute; using NSubstitute.ExceptionExtensions; @@ -117,7 +120,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -126,10 +129,7 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { - new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } } - } + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, Customer = new Customer { Id = customerId }, @@ -199,7 +199,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -208,10 +208,7 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { - new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } } - } + Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, Customer = new Customer @@ -233,7 +230,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -257,6 +254,9 @@ public class UpcomingInvoiceHandlerTests .IsEnabled(FeatureFlagKeys.PM23341_Milestone_2) .Returns(true); + var coupon = new Coupon { PercentOff = 20, Id = CouponIDs.Milestone2SubscriptionDiscount }; + + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); // Act await _sut.HandleAsync(parsedEvent); @@ -264,6 +264,7 @@ public class UpcomingInvoiceHandlerTests // Assert await _userRepository.Received(1).GetByIdAsync(_userId); await _pricingClient.Received(1).GetAvailablePremiumPlan(); + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone2SubscriptionDiscount); await _stripeFacade.Received(1).UpdateSubscription( Arg.Is("sub_123"), Arg.Is(o => @@ -272,11 +273,16 @@ public class UpcomingInvoiceHandlerTests o.Discounts[0].Coupon == CouponIDs.Milestone2SubscriptionDiscount && o.ProrationBehavior == "none")); - // Verify the updated invoice email was sent + // Verify the updated invoice email was sent with correct price + var discountedPrice = plan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; await _mailer.Received(1).SendEmail( - Arg.Is(email => + Arg.Is(email => email.ToEmails.Contains("user@example.com") && - email.Subject == "Your Subscription Will Renew Soon")); + email.Subject == "Your Bitwarden Premium renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); } [Fact] @@ -291,7 +297,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -307,7 +313,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; var organization = new Organization @@ -375,7 +381,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -395,7 +401,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; var organization = new Organization @@ -469,7 +475,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -489,7 +495,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; var organization = new Organization @@ -560,7 +566,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -576,7 +582,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "UK" }, TaxExempt = TaxExempt.None }; @@ -622,9 +628,8 @@ public class UpcomingInvoiceHandlerTests } [Fact] - public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsEmail() + public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsTraditionalEmail() { - // Arrange // Arrange var parsedEvent = new Event { Id = "evt_123" }; var customerId = "cus_123"; @@ -637,7 +642,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -646,10 +651,7 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { - new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } } - } + Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Customer = new Customer @@ -671,7 +673,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -708,11 +710,16 @@ public class UpcomingInvoiceHandlerTests Arg.Any(), Arg.Any>()); - // Verify that email was still sent despite the exception - await _mailer.Received(1).SendEmail( - Arg.Is(email => - email.ToEmails.Contains("user@example.com") && - email.Subject == "Your Subscription Will Renew Soon")); + // Verify that traditional email was sent when update fails + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + + // Verify renewal email was NOT sent + await _mailer.DidNotReceive().SendEmail(Arg.Any()); } [Fact] @@ -727,7 +734,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -737,12 +744,12 @@ public class UpcomingInvoiceHandlerTests Items = new StripeList(), AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, Customer = new Customer { Id = "cus_123" }, - Metadata = new Dictionary(), + Metadata = new Dictionary() }; var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -784,7 +791,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Free Item" } } + Data = [new() { Description = "Free Item" }] } }; var subscription = new Subscription @@ -800,7 +807,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -841,7 +848,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -856,7 +863,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -885,7 +892,7 @@ public class UpcomingInvoiceHandlerTests Arg.Any>(), Arg.Any()); - await _mailer.DidNotReceive().SendEmail(Arg.Any()); + await _mailer.DidNotReceive().SendEmail(Arg.Any()); } [Fact] @@ -900,7 +907,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; var subscription = new Subscription @@ -915,7 +922,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = "cus_123", - Subscriptions = new StripeList { Data = new List { subscription } } + Subscriptions = new StripeList { Data = [subscription] } }; _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); @@ -964,7 +971,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -977,8 +984,8 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { + Data = + [ new() { Id = passwordManagerItemId, @@ -989,7 +996,7 @@ public class UpcomingInvoiceHandlerTests Id = premiumAccessItemId, Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId } } - } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -998,7 +1005,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1009,8 +1016,11 @@ public class UpcomingInvoiceHandlerTests PlanType = PlanType.FamiliesAnnually2019 }; + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); _stripeEventUtilityService .GetIdsFromMetadata(subscription.Metadata) .Returns(new Tuple(_organizationId, null, null)); @@ -1036,6 +1046,8 @@ public class UpcomingInvoiceHandlerTests o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && o.ProrationBehavior == ProrationBehavior.None)); + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + await _organizationRepository.Received(1).ReplaceAsync( Arg.Is(org => org.Id == _organizationId && @@ -1045,9 +1057,13 @@ public class UpcomingInvoiceHandlerTests org.Seats == familiesPlan.PasswordManager.BaseSeats)); await _mailer.Received(1).SendEmail( - Arg.Is(email => + Arg.Is(email => email.ToEmails.Contains("org@example.com") && - email.Subject == "Your Subscription Will Renew Soon")); + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); } [Fact] @@ -1066,7 +1082,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -1079,14 +1095,14 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { + Data = + [ new() { Id = passwordManagerItemId, Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } } - } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -1095,7 +1111,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1141,7 +1157,7 @@ public class UpcomingInvoiceHandlerTests } [Fact] - public async Task HandleAsync_WhenMilestone3Disabled_DoesNotUpdateSubscription() + public async Task HandleAsync_WhenMilestone3Disabled_AndFamilies2019Plan_DoesNotUpdateSubscription() { // Arrange var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; @@ -1156,7 +1172,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -1168,14 +1184,14 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { + Data = + [ new() { Id = passwordManagerItemId, Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } } - } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -1184,7 +1200,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1232,7 +1248,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -1244,14 +1260,10 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { - new() - { - Id = "si_pm_123", - Price = new Price { Id = familiesPlan.PasswordManager.StripePlanId } - } - } + Data = + [ + new() { Id = "si_pm_123", Price = new Price { Id = familiesPlan.PasswordManager.StripePlanId } } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -1260,7 +1272,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1307,7 +1319,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -1319,14 +1331,10 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { - new() - { - Id = "si_different_item", - Price = new Price { Id = "different-price-id" } - } - } + Data = + [ + new() { Id = "si_different_item", Price = new Price { Id = "different-price-id" } } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -1335,7 +1343,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1378,7 +1386,7 @@ public class UpcomingInvoiceHandlerTests } [Fact] - public async Task HandleAsync_WhenMilestone3Enabled_AndUpdateFails_LogsError() + public async Task HandleAsync_WhenMilestone3Enabled_AndUpdateFails_LogsErrorAndSendsTraditionalEmail() { // Arrange var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; @@ -1393,7 +1401,7 @@ public class UpcomingInvoiceHandlerTests NextPaymentAttempt = DateTime.UtcNow.AddDays(7), Lines = new StripeList { - Data = new List { new() { Description = "Test Item" } } + Data = [new() { Description = "Test Item" }] } }; @@ -1406,14 +1414,14 @@ public class UpcomingInvoiceHandlerTests CustomerId = customerId, Items = new StripeList { - Data = new List - { + Data = + [ new() { Id = passwordManagerItemId, Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } } - } + ] }, AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Metadata = new Dictionary() @@ -1422,7 +1430,7 @@ public class UpcomingInvoiceHandlerTests var customer = new Customer { Id = customerId, - Subscriptions = new StripeList { Data = new List { subscription } }, + Subscriptions = new StripeList { Data = [subscription] }, Address = new Address { Country = "US" } }; @@ -1463,10 +1471,1060 @@ public class UpcomingInvoiceHandlerTests Arg.Any(), Arg.Any>()); - // Should still attempt to send email despite the failure - await _mailer.Received(1).SendEmail( - Arg.Is(email => - email.ToEmails.Contains("org@example.com") && - email.Subject == "Your Subscription Will Renew Soon")); + // Should send traditional email when update fails + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + + // Verify renewal email was NOT sent + await _mailer.DidNotReceive().SendEmail(Arg.Any()); } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndCouponNotFound_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns((Coupon)null); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to align subscription concerns for Organization ({_organizationId})") && + o.ToString().Contains(parsedEvent.Type) && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException && e.Message.Contains("Coupon for sending families 2019 email")), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndCouponPercentOffIsNull_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon + { + Id = CouponIDs.Milestone3SubscriptionDiscount, + PercentOff = null + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to align subscription concerns for Organization ({_organizationId})") && + o.ToString().Contains(parsedEvent.Type) && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException && e.Message.Contains("coupon.PercentOff")), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndSeatAddOnExists_DeletesItem() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 3 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 2 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == seatAddOnItemId && + o.Items[1].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndSeatAddOnWithQuantityOne_DeletesItem() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 1 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 2 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == seatAddOnItemId && + o.Items[1].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_WithPremiumAccessAndSeatAddOn_UpdatesBothItems() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var premiumAccessItemId = "si_premium_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = premiumAccessItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 2 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 3 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == premiumAccessItemId && + o.Items[1].Deleted == true && + o.Items[2].Id == seatAddOnItemId && + o.Items[2].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndFamilies2025Plan_UpdatesSubscriptionOnlyNoAddons() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2025Plan = new Families2025Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2025 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2025).Returns(families2025Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Discounts == null && + o.ProrationBehavior == ProrationBehavior.None)); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.MonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")))); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Disabled_AndFamilies2025Plan_DoesNotUpdateSubscription() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2025Plan = new Families2025Plan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2025 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2025).Returns(families2025Plan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - should not update subscription or organization when feature flag is disabled + await _stripeFacade.DidNotReceive().UpdateSubscription( + Arg.Any(), + Arg.Any()); + + await _organizationRepository.DidNotReceive().ReplaceAsync( + Arg.Is(org => org.PlanType == PlanType.FamiliesAnnually)); + } + + #region Premium Renewal Email Tests + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndCouponNotFound_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns((Coupon)null); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException + && e.Message == $"Coupon for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} not found"), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndCouponPercentOffIsNull_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + var coupon = new Coupon + { + Id = CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = null + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException + && e.Message == $"coupon.PercentOff for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} is null"), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndValidCoupon_SendsPremiumRenewalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + var coupon = new Coupon + { + Id = CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 30 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + var expectedDiscountedPrice = plan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("user@example.com") && + email.Subject == "Your Bitwarden Premium renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == "30%" && + email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US")) + )); + + await _mailService.DidNotReceive().SendInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndGetCouponThrowsException_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount) + .ThrowsAsync(new StripeException("Stripe API error")); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is StripeException), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + #endregion } diff --git a/test/Core.IntegrationTest/Core.IntegrationTest.csproj b/test/Core.IntegrationTest/Core.IntegrationTest.csproj index 21b746c2fb..133793d3d8 100644 --- a/test/Core.IntegrationTest/Core.IntegrationTest.csproj +++ b/test/Core.IntegrationTest/Core.IntegrationTest.csproj @@ -11,11 +11,11 @@ - - + + - - + + diff --git a/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs index e906862e3f..c874fe58d8 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs @@ -1,6 +1,8 @@ -using System.Text.Json; +using System.Reflection; +using System.Text.Json; using AutoFixture; using AutoFixture.Kernel; +using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; @@ -9,7 +11,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.Data; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.DataProtection; @@ -20,12 +22,24 @@ public class OrganizationCustomization : ICustomization { public bool UseGroups { get; set; } public PlanType PlanType { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + + public OrganizationCustomization() + { + + } + + public OrganizationCustomization(bool useAutomaticUserConfirmation, PlanType planType) + { + UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + PlanType = planType; + } public void Customize(IFixture fixture) { var organizationId = Guid.NewGuid(); var maxCollections = (short)new Random().Next(10, short.MaxValue); - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == PlanType); + var plan = MockPlans.Plans.FirstOrDefault(p => p.Type == PlanType); var seats = (short)new Random().Next(plan.PasswordManager.BaseSeats, plan.PasswordManager.MaxSeats ?? short.MaxValue); var smSeats = plan.SupportsSecretsManager ? (short?)new Random().Next(plan.SecretsManager.BaseSeats, plan.SecretsManager.MaxSeats ?? short.MaxValue) @@ -37,7 +51,8 @@ public class OrganizationCustomization : ICustomization .With(o => o.UseGroups, UseGroups) .With(o => o.PlanType, PlanType) .With(o => o.Seats, seats) - .With(o => o.SmSeats, smSeats)); + .With(o => o.SmSeats, smSeats) + .With(o => o.UseAutomaticUserConfirmation, UseAutomaticUserConfirmation)); fixture.Customize(composer => composer @@ -77,7 +92,7 @@ internal class PaidOrganization : ICustomization public PlanType CheckedPlanType { get; set; } public void Customize(IFixture fixture) { - var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && p.LegacyYear == null).OrderBy(p => p.UpgradeSortOrder).Select(p => p.Type).ToList(); + var validUpgradePlans = MockPlans.Plans.Where(p => p.Type != PlanType.Free && p.LegacyYear == null).OrderBy(p => p.UpgradeSortOrder).Select(p => p.Type).ToList(); var lowestActivePaidPlan = validUpgradePlans.First(); CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; validUpgradePlans.Remove(lowestActivePaidPlan); @@ -105,7 +120,7 @@ internal class FreeOrganizationUpgrade : ICustomization .With(o => o.PlanType, PlanType.Free)); var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; - var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); + var selectedPlan = MockPlans.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); fixture.Customize(composer => composer .With(ou => ou.Plan, selectedPlan.Type) @@ -153,7 +168,7 @@ public class SecretsManagerOrganizationCustomization : ICustomization .With(o => o.Id, organizationId) .With(o => o.UseSecretsManager, true) .With(o => o.PlanType, planType) - .With(o => o.Plan, StaticStore.GetPlan(planType).Name) + .With(o => o.Plan, MockPlans.Get(planType).Name) .With(o => o.MaxAutoscaleSmSeats, (int?)null) .With(o => o.MaxAutoscaleSmServiceAccounts, (int?)null)); } @@ -277,3 +292,9 @@ internal class EphemeralDataProtectionAutoDataAttribute : CustomAutoDataAttribut public EphemeralDataProtectionAutoDataAttribute() : base(new SutProviderCustomization(), new EphemeralDataProtectionCustomization()) { } } + +internal class OrganizationAttribute(bool useAutomaticUserConfirmation = false, PlanType planType = PlanType.Free) : CustomizeAttribute +{ + public override ICustomization GetCustomization(ParameterInfo parameter) => + new OrganizationCustomization(useAutomaticUserConfirmation, planType); +} diff --git a/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs index 634b234e70..53511de550 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs @@ -2,6 +2,7 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Enums; +using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; namespace Bit.Core.Test.AdminConsole.AutoFixture; @@ -9,10 +10,16 @@ namespace Bit.Core.Test.AdminConsole.AutoFixture; internal class OrganizationUserPolicyDetailsCustomization : ICustomization { public PolicyType Type { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType UserType { get; set; } + public bool IsProvider { get; set; } - public OrganizationUserPolicyDetailsCustomization(PolicyType type) + public OrganizationUserPolicyDetailsCustomization(PolicyType type, OrganizationUserStatusType status, OrganizationUserType userType, bool isProvider) { Type = type; + Status = status; + UserType = userType; + IsProvider = isProvider; } public void Customize(IFixture fixture) @@ -20,6 +27,9 @@ internal class OrganizationUserPolicyDetailsCustomization : ICustomization fixture.Customize(composer => composer .With(o => o.OrganizationId, Guid.NewGuid()) .With(o => o.PolicyType, Type) + .With(o => o.OrganizationUserStatus, Status) + .With(o => o.OrganizationUserType, UserType) + .With(o => o.IsProvider, IsProvider) .With(o => o.PolicyEnabled, true)); } } @@ -27,14 +37,25 @@ internal class OrganizationUserPolicyDetailsCustomization : ICustomization public class OrganizationUserPolicyDetailsAttribute : CustomizeAttribute { private readonly PolicyType _type; + private readonly OrganizationUserStatusType _status; + private readonly OrganizationUserType _userType; + private readonly bool _isProvider; - public OrganizationUserPolicyDetailsAttribute(PolicyType type) + public OrganizationUserPolicyDetailsAttribute(PolicyType type) : this(type, OrganizationUserStatusType.Accepted, OrganizationUserType.User, false) { _type = type; } + public OrganizationUserPolicyDetailsAttribute(PolicyType type, OrganizationUserStatusType status, OrganizationUserType userType, bool isProvider) + { + _type = type; + _status = status; + _userType = userType; + _isProvider = isProvider; + } + public override ICustomization GetCustomization(ParameterInfo parameter) { - return new OrganizationUserPolicyDetailsCustomization(_type); + return new OrganizationUserPolicyDetailsCustomization(_type, _status, _userType, _isProvider); } } diff --git a/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs new file mode 100644 index 0000000000..08fcd23969 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs @@ -0,0 +1,866 @@ +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.Services.NoopImplementations; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using NSubstitute; +using StackExchange.Redis; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations; + +public class EventIntegrationServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _services; + private readonly GlobalSettings _globalSettings; + + public EventIntegrationServiceCollectionExtensionsTests() + { + _services = new ServiceCollection(); + _globalSettings = CreateGlobalSettings([]); + + // Add required infrastructure services + _services.TryAddSingleton(_globalSettings); + _services.TryAddSingleton(_globalSettings); + _services.AddLogging(); + + // Mock Redis connection for cache + _services.AddSingleton(Substitute.For()); + + // Mock required repository dependencies for commands + _services.TryAddScoped(_ => Substitute.For()); + _services.TryAddScoped(_ => Substitute.For()); + _services.TryAddScoped(_ => Substitute.For()); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_RegistersAllServices() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + using var provider = _services.BuildServiceProvider(); + + var cache = provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName); + Assert.NotNull(cache); + + var validator = provider.GetRequiredService(); + Assert.NotNull(validator); + + using var scope = provider.CreateScope(); + var sp = scope.ServiceProvider; + + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries_AreRegisteredAsScoped() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var createIntegrationDescriptor = _services.First(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)); + var createConfigDescriptor = _services.First(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); + + Assert.Equal(ServiceLifetime.Scoped, createIntegrationDescriptor.Lifetime); + Assert.Equal(ServiceLifetime.Scoped, createConfigDescriptor.Lifetime); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries_DifferentInstancesPerScope() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var provider = _services.BuildServiceProvider(); + + ICreateOrganizationIntegrationCommand? instance1, instance2, instance3; + using (var scope1 = provider.CreateScope()) + { + instance1 = scope1.ServiceProvider.GetService(); + } + using (var scope2 = provider.CreateScope()) + { + instance2 = scope2.ServiceProvider.GetService(); + } + using (var scope3 = provider.CreateScope()) + { + instance3 = scope3.ServiceProvider.GetService(); + } + + Assert.NotNull(instance1); + Assert.NotNull(instance2); + Assert.NotNull(instance3); + Assert.NotSame(instance1, instance2); + Assert.NotSame(instance2, instance3); + Assert.NotSame(instance1, instance3); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries__SameInstanceWithinScope() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + var provider = _services.BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var instance1 = scope.ServiceProvider.GetService(); + var instance2 = scope.ServiceProvider.GetService(); + + Assert.NotNull(instance1); + Assert.NotNull(instance2); + Assert.Same(instance1, instance2); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var createConfigCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); + Assert.Single(createConfigCmdDescriptors); + + var updateIntegrationCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(IUpdateOrganizationIntegrationCommand)).ToList(); + Assert.Single(updateIntegrationCmdDescriptors); + } + + [Fact] + public void AddOrganizationIntegrationCommandsQueries_RegistersAllIntegrationServices() + { + _services.AddOrganizationIntegrationCommandsQueries(); + + Assert.Contains(_services, s => s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IUpdateOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IDeleteOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IGetOrganizationIntegrationsQuery)); + } + + [Fact] + public void AddOrganizationIntegrationCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddOrganizationIntegrationCommandsQueries(); + _services.AddOrganizationIntegrationCommandsQueries(); + _services.AddOrganizationIntegrationCommandsQueries(); + + var createCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)).ToList(); + Assert.Single(createCmdDescriptors); + } + + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_RegistersAllConfigurationServices() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + Assert.Contains(_services, s => s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IUpdateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IDeleteOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IGetOrganizationIntegrationConfigurationsQuery)); + } + + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + var createCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); + Assert.Single(createCmdDescriptors); + } + + [Fact] + public void IsRabbitMqEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingHostName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingUsername_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingPassword_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = null, + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingExchangeName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingConnectionString_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = null, + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingTopicName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void AddSlackService_AllSettingsPresent_RegistersSlackService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = "test-client-id", + ["GlobalSettings:Slack:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Slack:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddSlackService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = null, + ["GlobalSettings:Slack:ClientSecret"] = null, + ["GlobalSettings:Slack:Scopes"] = null + }); + + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + } + + [Fact] + public void AddTeamsService_AllSettingsPresent_RegistersTeamsServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = "test-client-id", + ["GlobalSettings:Teams:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Teams:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.TryAddScoped(_ => Substitute.For()); + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + + var teamsService = provider.GetService(); + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + + var bot = provider.GetService(); + Assert.NotNull(bot); + Assert.IsType(bot); + + var adapter = provider.GetService(); + Assert.NotNull(adapter); + Assert.IsType(adapter); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddTeamsService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = null, + ["GlobalSettings:Teams:ClientSecret"] = null, + ["GlobalSettings:Teams:Scopes"] = null + }); + + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + var teamsService = provider.GetService(); + + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.TryAddSingleton(TimeProvider.System); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_RegistersCommonServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify common services are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationFilterService)); + Assert.Contains(services, s => s.ServiceType == typeof(TimeProvider)); + + // Verify HttpClients for handlers are registered + var httpClientDescriptors = services.Where(s => s.ServiceType == typeof(IHttpClientFactory)).ToList(); + Assert.NotEmpty(httpClientDescriptors); + } + + [Fact] + public void AddEventIntegrationServices_RegistersIntegrationHandlers() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify integration handlers are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + } + + [Fact] + public void AddEventIntegrationServices_RabbitMqEnabled_RegistersRabbitMqListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for RabbitMQ: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_AzureServiceBusEnabled_RegistersAzureListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_BothEnabled_AzureServiceBusTakesPrecedence() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + // NO RabbitMQ services should be enabled because ASB takes precedence + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_NeitherEnabled_RegistersNoListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register no hosted services when neither RabbitMQ nor Azure Service Bus is enabled + Assert.Equal(0, afterCount - beforeCount); + } + + [Fact] + public void AddEventWriteServices_AzureServiceBusEnabled_RegistersAzureServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_EventsConnectionStringPresent_RegistersAzureQueueService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Events:ConnectionString"] = "DefaultEndpointsProtocol=https;AccountName=test;AccountKey=test;EndpointSuffix=core.windows.net", + ["GlobalSettings:Events:QueueName"] = "event" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(AzureQueueEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_SelfHosted_RegistersRepositoryService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:SelfHosted"] = "true" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(RepositoryEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_NothingEnabled_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(NoopEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_AzureTakesPrecedenceOverRabbitMq() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + services.AddEventWriteServices(globalSettings); + + // Should use Azure Service Bus, not RabbitMQ + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.DoesNotContain(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusEnabled_RegistersAzureServiceBusServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IAzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventRepository)); + Assert.Contains(services, s => s.ServiceType == typeof(AzureTableStorageEventHandler)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddAzureServiceBusListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IRabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(EventRepositoryHandler)); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddRabbitMqListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + + private static GlobalSettings CreateGlobalSettings(Dictionary data) + { + var config = new ConfigurationBuilder() + .AddInMemoryCollection(data) + .Build(); + + var settings = new GlobalSettings(); + config.GetSection("GlobalSettings").Bind(settings); + return settings; + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..c6c8a44955 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,179 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class CreateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task CreateAsync_Success_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_WildcardSuccess_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..3b12f4bd88 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,211 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class DeleteOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_Success_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_WildcardSuccess_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs new file mode 100644 index 0000000000..18541df53e --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs @@ -0,0 +1,101 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class GetOrganizationIntegrationConfigurationsQueryTests +{ + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_Success_ReturnsConfigurations( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + List configurations) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(integrationId); + Assert.Equal(configurations.Count, result.Count); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_NoConfigurations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns([]); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..c2eeefc087 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,390 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class UpdateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_Success_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WildcardSuccess_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedEventType_UpdatesConfigurationAndInvalidatesCacheForBothTypes( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = EventType.Cipher_Created; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + updatedConfiguration.EventType.Value)); + // Verify RemoveByTagAsync was NOT called since both are specific event types + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromWildcardToSpecific_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; // Wildcard + updatedConfiguration.EventType = EventType.User_LoggedIn; // Specific + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromSpecificToWildcard_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; // Specific + updatedConfiguration.EventType = null; // Wildcard + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..62af1eb3ed --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs @@ -0,0 +1,92 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class CreateOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task CreateAsync_Success_CreatesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(integration) + .Returns(integration); + + var result = await sutProvider.Sut.CreateAsync(integration); + + await sutProvider.GetDependency().Received(1) + .GetManyByOrganizationAsync(integration.OrganizationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(integration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + integration.OrganizationId, + integration.Type)); + Assert.Equal(integration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_DuplicateType_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration existingIntegration) + { + integration.Type = IntegrationType.Webhook; + existingIntegration.Type = IntegrationType.Webhook; + existingIntegration.OrganizationId = integration.OrganizationId; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([existingIntegration]); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(integration)); + + Assert.Contains("An integration of this type already exists", exception.Message); + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_DifferentType_Success( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration existingIntegration) + { + integration.Type = IntegrationType.Webhook; + existingIntegration.Type = IntegrationType.Slack; + existingIntegration.OrganizationId = integration.OrganizationId; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([existingIntegration]); + sutProvider.GetDependency() + .CreateAsync(integration) + .Returns(integration); + + var result = await sutProvider.Sut.CreateAsync(integration); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(integration); + Assert.Equal(integration, result); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..25a00bded1 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs @@ -0,0 +1,86 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class DeleteOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_Success_DeletesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(integration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs new file mode 100644 index 0000000000..dfa8e4b306 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs @@ -0,0 +1,44 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class GetOrganizationIntegrationsQueryTests +{ + [Theory, BitAutoData] + public async Task GetManyByOrganizationAsync_CallsRepository( + SutProvider sutProvider, + Guid organizationId, + List integrations) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns(integrations); + + var result = await sutProvider.Sut.GetManyByOrganizationAsync(organizationId); + + await sutProvider.GetDependency().Received(1) + .GetManyByOrganizationAsync(organizationId); + Assert.Equal(integrations.Count, result.Count); + } + + [Theory, BitAutoData] + public async Task GetManyByOrganizationAsync_NoIntegrations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + + var result = await sutProvider.Sut.GetManyByOrganizationAsync(organizationId); + + Assert.Empty(result); + } +} diff --git a/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..fdedec2e51 --- /dev/null +++ b/test/Core.Test/AdminConsole/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs @@ -0,0 +1,121 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.AdminConsole.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class UpdateOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_Success_UpdatesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = organizationId; + existingIntegration.Type = IntegrationType.Webhook; + updatedIntegration.Id = integrationId; + updatedIntegration.OrganizationId = organizationId; + updatedIntegration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedIntegration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + existingIntegration.Type)); + Assert.Equal(updatedIntegration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration updatedIntegration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationIsDifferentType_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = organizationId; + existingIntegration.Type = IntegrationType.Webhook; + updatedIntegration.Id = integrationId; + updatedIntegration.OrganizationId = organizationId; + updatedIntegration.Type = IntegrationType.Hec; // Different Type + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs index cdb109e285..d9a3cd6e8a 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs +++ b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs @@ -2,8 +2,8 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Entities; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; @@ -35,7 +35,7 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void UserName_WhenUserIsSet_ReturnsName(EventMessage eventMessage, User user) + public void UserName_WhenUserIsSet_ReturnsName(EventMessage eventMessage, OrganizationUserUserDetails user) { var sut = new IntegrationTemplateContext(eventMessage) { User = user }; @@ -51,7 +51,7 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void UserEmail_WhenUserIsSet_ReturnsEmail(EventMessage eventMessage, User user) + public void UserEmail_WhenUserIsSet_ReturnsEmail(EventMessage eventMessage, OrganizationUserUserDetails user) { var sut = new IntegrationTemplateContext(eventMessage) { User = user }; @@ -67,7 +67,23 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void ActingUserName_WhenActingUserIsSet_ReturnsName(EventMessage eventMessage, User actingUser) + public void UserType_WhenUserIsSet_ReturnsType(EventMessage eventMessage, OrganizationUserUserDetails user) + { + var sut = new IntegrationTemplateContext(eventMessage) { User = user }; + + Assert.Equal(user.Type, sut.UserType); + } + + [Theory, BitAutoData] + public void UserType_WhenUserIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { User = null }; + + Assert.Null(sut.UserType); + } + + [Theory, BitAutoData] + public void ActingUserName_WhenActingUserIsSet_ReturnsName(EventMessage eventMessage, OrganizationUserUserDetails actingUser) { var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; @@ -83,7 +99,7 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void ActingUserEmail_WhenActingUserIsSet_ReturnsEmail(EventMessage eventMessage, User actingUser) + public void ActingUserEmail_WhenActingUserIsSet_ReturnsEmail(EventMessage eventMessage, OrganizationUserUserDetails actingUser) { var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; @@ -98,6 +114,22 @@ public class IntegrationTemplateContextTests Assert.Null(sut.ActingUserEmail); } + [Theory, BitAutoData] + public void ActingUserType_WhenActingUserIsSet_ReturnsType(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; + + Assert.Equal(actingUser.Type, sut.ActingUserType); + } + + [Theory, BitAutoData] + public void ActingUserType_WhenActingUserIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = null }; + + Assert.Null(sut.ActingUserType); + } + [Theory, BitAutoData] public void OrganizationName_WhenOrganizationIsSet_ReturnsDisplayName(EventMessage eventMessage, Organization organization) { @@ -113,4 +145,20 @@ public class IntegrationTemplateContextTests Assert.Null(sut.OrganizationName); } + + [Theory, BitAutoData] + public void GroupName_WhenGroupIsSet_ReturnsName(EventMessage eventMessage, Group group) + { + var sut = new IntegrationTemplateContext(eventMessage) { Group = group }; + + Assert.Equal(group.Name, sut.GroupName); + } + + [Theory, BitAutoData] + public void GroupName_WhenGroupIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { Group = null }; + + Assert.Null(sut.GroupName); + } } diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs index 916fe981de..50442dd463 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs +++ b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs @@ -17,4 +17,5 @@ public class TestListenerConfiguration : IIntegrationListenerConfiguration public int EventPrefetchCount => 0; public int IntegrationMaxConcurrentCalls => 1; public int IntegrationPrefetchCount => 0; + public string RoutingKey => IntegrationType.ToRoutingKey(); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs index 933bcbc3a1..efcd57b6ad 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.Import; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -57,7 +58,7 @@ public class ImportOrganizationUsersAndGroupsCommandTests var organizationUserRepository = sutProvider.GetDependency(); SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository); - sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id).Returns(existingUsers); sutProvider.GetDependency().GetOccupiedSeatCountByOrganizationIdAsync(org.Id).Returns( new OrganizationSeatCounts diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs index 3f0443d31b..ef4c2c941e 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs @@ -2,7 +2,6 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; @@ -183,17 +182,17 @@ public class VerifyOrganizationDomainCommandTests _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) - .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && - x.OrganizationId == domain.OrganizationId && - x.Enabled && + .SaveAsync(Arg.Is(x => x.PolicyUpdate.Type == PolicyType.SingleOrg && + x.PolicyUpdate.OrganizationId == domain.OrganizationId && + x.PolicyUpdate.Enabled && x.PerformedBy is StandardUser && x.PerformedBy.UserId == userId)); } [Theory, BitAutoData] - public async Task UserVerifyOrganizationDomainAsync_WhenPolicyValidatorsRefactorFlagEnabled_UsesVNextSavePolicyCommand( + public async Task UserVerifyOrganizationDomainAsync_UsesVNextSavePolicyCommand( OrganizationDomain domain, Guid userId, SutProvider sutProvider) { sutProvider.GetDependency() @@ -207,10 +206,6 @@ public class VerifyOrganizationDomainCommandTests sutProvider.GetDependency() .UserId.Returns(userId); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(true); - _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); await sutProvider.GetDependency() @@ -240,9 +235,9 @@ public class VerifyOrganizationDomainCommandTests _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .SaveAsync(Arg.Any()); + .SaveAsync(Arg.Any()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs index 540bac4d1c..82d4eceaed 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs @@ -1,7 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; @@ -24,6 +26,7 @@ using Bit.Test.Common.Fakes; using Microsoft.AspNetCore.DataProtection; using NSubstitute; using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.Test.OrganizationFeatures.OrganizationUsers; @@ -673,6 +676,79 @@ public class AcceptOrgUserCommandTests Assert.Equal("User not found within organization.", exception.Message); } + // Auto-confirm policy validation tests -------------------------------------------------------------------------- + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithAutoConfirmIsNotEnabled_DoesNotCheckCompliance( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + // Act + var resultOrgUser = await sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService); + + // Assert + AssertValidAcceptedOrgUser(resultOrgUser, orgUser, user); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .IsCompliantAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithUserThatIsCompliantWithAutoConfirm_AcceptsUser( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + // Mock auto-confirm enforcement query to return valid (no auto-confirm restrictions) + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user))); + + // Act + var resultOrgUser = await sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService); + + // Assert + AssertValidAcceptedOrgUser(resultOrgUser, orgUser, user); + + await sutProvider.GetDependency().Received(1).ReplaceAsync( + Arg.Is(ou => ou.Id == orgUser.Id && ou.Status == OrganizationUserStatusType.Accepted)); + } + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithAutoConfirmIsEnabledAndFailsCompliance_ThrowsBadRequestException( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails, + OrganizationUser otherOrgUser) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService)); + + // Should get auto-confirm error + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + } + // Private helpers ------------------------------------------------------------------------------------------------- /// @@ -716,7 +792,7 @@ public class AcceptOrgUserCommandTests /// - Provides mock data for an admin to validate email functionality. /// - Returns the corresponding organization for the given org ID. /// - private void SetupCommonAcceptOrgUserMocks(SutProvider sutProvider, User user, + private static void SetupCommonAcceptOrgUserMocks(SutProvider sutProvider, User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) { @@ -729,18 +805,12 @@ public class AcceptOrgUserCommandTests // User is not part of any other orgs sutProvider.GetDependency() .GetManyByUserAsync(user.Id) - .Returns( - Task.FromResult>(new List()) - ); + .Returns([]); // Org they are trying to join does not have single org policy sutProvider.GetDependency() .GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg, OrganizationUserStatusType.Invited) - .Returns( - Task.FromResult>( - new List() - ) - ); + .Returns([]); // User is not part of any organization that applies the single org policy sutProvider.GetDependency() @@ -750,20 +820,24 @@ public class AcceptOrgUserCommandTests // Org does not require 2FA sutProvider.GetDependency().GetPoliciesApplicableToUserAsync(user.Id, PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited) - .Returns(Task.FromResult>( - new List())); + .Returns([]); // Provide at least 1 admin to test email functionality sutProvider.GetDependency() .GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin) - .Returns(Task.FromResult>( - new List() { adminUserDetails } - )); + .Returns([adminUserDetails]); // Return org sutProvider.GetDependency() .GetByIdAsync(org.Id) - .Returns(Task.FromResult(org)); + .Returns(org); + + // Auto-confirm enforcement query returns valid by default (no restrictions) + var request = new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user); + + sutProvider.GetDependency() + .IsCompliantAsync(request) + .Returns(Valid(request)); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs new file mode 100644 index 0000000000..c3fb52ecbe --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs @@ -0,0 +1,639 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Core.Test.AutoFixture.OrganizationFixtures; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUsers; + +[SutProviderCustomize] +public class AutomaticallyConfirmOrganizationUsersValidatorTests +{ + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullOrganizationUser_ReturnsUserNotFoundError( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = null, + OrganizationUserId = Guid.NewGuid(), + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullUserId_ReturnsUserNotFoundError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = null; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullOrganization_ReturnsOrganizationNotFoundError( + SutProvider sutProvider, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = null, + OrganizationId = organizationUser.OrganizationId, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithValidAcceptedUser_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true, planType: PlanType.EnterpriseAnnually)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + Assert.Equal(request, result.Request); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMismatchedOrganizationId_ReturnsOrganizationUserIdIsInvalidError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = Guid.NewGuid(); // Different from organization.Id + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData(OrganizationUserStatusType.Invited)] + [BitAutoData(OrganizationUserStatusType.Revoked)] + [BitAutoData(OrganizationUserStatusType.Confirmed)] + public async Task ValidateAsync_WithNotAcceptedStatus_ReturnsUserIsNotAcceptedError( + OrganizationUserStatusType statusType, + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = statusType; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Custom)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task ValidateAsync_WithNonUserType_ReturnsUserIsNotUserTypeError( + OrganizationUserType userType, + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Type = userType; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWithout2FA_And2FARequired_ReturnsError( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + var twoFactorPolicyDetails = new PolicyDetails + { + OrganizationId = organization.Id, + PolicyType = PolicyType.TwoFactorAuthentication + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, false)]); + + sutProvider.GetDependency() + .GetAsync(userId) + .Returns(new RequireTwoFactorPolicyRequirement([twoFactorPolicyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWith2FA_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWithout2FA_And2FANotRequired_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, false)]); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new RequireTwoFactorPolicyRequirement([])); // No 2FA policy + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserInSingleOrg_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); // Single org + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithAutoConfirmPolicyDisabled_ReturnsAutoConfirmPolicyNotEnabledError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns((Policy)null); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithOrganizationUseAutomaticUserConfirmationDisabled_ReturnsAutoConfirmPolicyNotEnabledError( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: false)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNonProviderUser_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs new file mode 100644 index 0000000000..1035d5c578 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs @@ -0,0 +1,730 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUsers; + +[SutProviderCustomize] +public class AutomaticallyConfirmUsersCommandTests +{ + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithValidRequest_ConfirmsUserSuccessfully( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)); + + await AssertSuccessfulOperationsAsync(sutProvider, organizationUser, organization, user, key); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithInvalidUserOrgId_ReturnsOrganizationUserIdIsInvalidError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = Guid.NewGuid(); // User belongs to another organization + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, false, new OrganizationUserIdIsInvalid()); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + + await sutProvider.GetDependency() + .DidNotReceive() + .ConfirmOrganizationUserAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenAlreadyConfirmed_ReturnsNoneSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + // Return false to indicate the user is already confirmed + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(x => + x.OrganizationUserId == organizationUser.Id && x.Key == request.Key)) + .Returns(false); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .ConfirmOrganizationUserAsync(Arg.Is(x => + x.OrganizationUserId == organizationUser.Id && x.Key == request.Key)); + + // Verify no side effects occurred + await sutProvider.GetDependency() + .DidNotReceive() + .LogOrganizationUserEventAsync(Arg.Any(), Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceive() + .PushSyncOrgKeysAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectionEnabled_CreatesDefaultCollection( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, // Non-empty to trigger creation + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, true); // Policy requires collection + + sutProvider.GetDependency().ConfirmOrganizationUserAsync( + Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync( + Arg.Is(c => + c.OrganizationId == organization.Id && + c.Name == defaultCollectionName && + c.Type == CollectionType.DefaultUserCollection), + Arg.Is>(groups => groups == null), + Arg.Is>(access => + access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null)); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectionDisabled_DoesNotCreateCollection( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = string.Empty, // Empty, so the collection won't be created + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, false); // Policy doesn't require + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .DidNotReceive() + .CreateAsync(Arg.Any(), + Arg.Any>(), + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenCreateDefaultCollectionFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, // Non-empty to trigger creation + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)).Returns(true); + + var collectionException = new Exception("Collection creation failed"); + sutProvider.GetDependency() + .CreateAsync(Arg.Any(), + Arg.Any>(), + Arg.Any>()) + .ThrowsAsync(collectionException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if collection creation fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to create default collection for user")), + collectionException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenEventLogFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var eventException = new Exception("Event logging failed"); + sutProvider.GetDependency() + .LogOrganizationUserEventAsync(Arg.Any(), + EventType.OrganizationUser_AutomaticallyConfirmed, + Arg.Any()) + .ThrowsAsync(eventException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if event log fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to log OrganizationUser_AutomaticallyConfirmed event")), + eventException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenSendEmailFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var emailException = new Exception("Email sending failed"); + sutProvider.GetDependency() + .SendOrganizationConfirmedEmailAsync(organization.Name, user.Email, organizationUser.AccessSecretsManager) + .ThrowsAsync(emailException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if email fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to send OrganizationUserConfirmed")), + emailException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenUserNotFoundForEmail_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Return null when retrieving user for email + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns((User)null!); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if user not found for email + Assert.True(result.IsSuccess); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenDeleteDeviceRegistrationFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName, + Device device) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + device.UserId = user.Id; + device.PushToken = "test-push-token"; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List { device }); + + var deviceException = new Exception("Device registration deletion failed"); + sutProvider.GetDependency() + .DeleteUserRegistrationOrganizationAsync(Arg.Any>(), organization.Id.ToString()) + .ThrowsAsync(deviceException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if device registration deletion fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to delete device registration")), + deviceException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenPushSyncOrgKeysFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var pushException = new Exception("Push sync failed"); + sutProvider.GetDependency() + .PushSyncOrgKeysAsync(user.Id) + .ThrowsAsync(pushException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if push sync fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to push organization keys")), + pushException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDevicesWithoutPushToken_FiltersCorrectly( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName, + Device deviceWithToken, + Device deviceWithoutToken) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + deviceWithToken.UserId = user.Id; + deviceWithToken.PushToken = "test-token"; + deviceWithoutToken.UserId = user.Id; + deviceWithoutToken.PushToken = null; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List { deviceWithToken, deviceWithoutToken }); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .DeleteUserRegistrationOrganizationAsync( + Arg.Is>(devices => + devices.Count(d => deviceWithToken.Id.ToString() == d) == 1), + organization.Id.ToString()); + } + + private static void SetupRepositoryMocks( + SutProvider sutProvider, + OrganizationUser organizationUser, + Organization organization, + User user) + { + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List()); + } + + private static void SetupValidatorMock( + SutProvider sutProvider, + AutomaticallyConfirmOrganizationUserRequest originalRequest, + OrganizationUser organizationUser, + Organization organization, + bool isValid, + Error? error = null) + { + var validationRequest = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = originalRequest.PerformedBy, + DefaultUserCollectionName = originalRequest.DefaultUserCollectionName, + OrganizationUserId = originalRequest.OrganizationUserId, + OrganizationUser = organizationUser, + OrganizationId = originalRequest.OrganizationId, + Organization = organization, + Key = originalRequest.Key + }; + + var validationResult = isValid + ? ValidationResultHelpers.Valid(validationRequest) + : ValidationResultHelpers.Invalid(validationRequest, error ?? new UserIsNotAccepted()); + + sutProvider.GetDependency() + .ValidateAsync(Arg.Any()) + .Returns(validationResult); + } + + private static void SetupPolicyRequirementMock( + SutProvider sutProvider, + Guid userId, + Guid organizationId, + bool requiresDefaultCollection) + { + var policyDetails = requiresDefaultCollection + ? new List { new() { OrganizationId = organizationId } } + : new List(); + + var policyRequirement = new OrganizationDataOwnershipPolicyRequirement( + requiresDefaultCollection ? OrganizationDataOwnershipState.Enabled : OrganizationDataOwnershipState.Disabled, + policyDetails); + + sutProvider.GetDependency() + .GetAsync(userId) + .Returns(policyRequirement); + } + + private static async Task AssertSuccessfulOperationsAsync( + SutProvider sutProvider, + OrganizationUser organizationUser, + Organization organization, + User user, + string key) + { + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync( + Arg.Is(x => x.Id == organizationUser.Id), + EventType.OrganizationUser_AutomaticallyConfirmed, + Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationConfirmedEmailAsync( + organization.Name, + user.Email, + organizationUser.AccessSecretsManager); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(user.Id); + + await sutProvider.GetDependency() + .Received(1) + .DeleteUserRegistrationOrganizationAsync( + Arg.Any>(), + organization.Id.ToString()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index 86b068b88f..5528ecb2a2 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -2,7 +2,9 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -21,6 +23,7 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers; @@ -559,4 +562,256 @@ public class ConfirmOrganizationUserCommandTests .DidNotReceive() .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledAndUserBelongsToAnotherOrg_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, string key, SutProvider sutProvider) + { + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); // Different org + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledForOtherOrg_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new OtherOrganizationDoesNotAllowOtherMembership())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new OtherOrganizationDoesNotAllowOtherMembership().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledAndUserIsProvider_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user), + new ProviderUsersCannotJoin())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new ProviderUsersCannotJoin().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmNotApplicable_Succeeds( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user))); + + // Act + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id); + + // Assert + await sutProvider.GetDependency() + .Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await sutProvider.GetDependency() + .Received(1).SendOrganizationConfirmedEmailAsync(org.DisplayName(), user.Email, orgUser.AccessSecretsManager); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmValidationBeforeSingleOrgPolicy_ChecksAutoConfirmFirst( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, + [OrganizationUserPolicyDetails(PolicyType.SingleOrg)] OrganizationUserPolicyDetails singleOrgPolicy, + string key, SutProvider sutProvider) + { + // Arrange - Setup conditions that would fail BOTH auto-confirm AND single org policy + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + singleOrgPolicy.OrganizationId = org.Id; + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg) + .Returns([singleOrgPolicy]); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + Assert.NotEqual("Cannot confirm this member to the organization until they leave or remove all other organizations.", + exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabled_MixedResults( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, + OrganizationUser otherOrgUser, User user1, User user2, User user3, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser1.UserId = user1.Id; + orgUser2.UserId = user2.Id; + orgUser3.UserId = user3.Id; + otherOrgUser.UserId = user3.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs(orgUsers); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([user1, user2, user3]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser1, orgUser2, orgUser3, otherOrgUser]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user1.Id)) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser1], user1))); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user2.Id)) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser2], user2))); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user3.Id)) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser3, otherOrgUser], user3), + new OtherOrganizationDoesNotAllowOtherMembership())); + + var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id); + + // Assert + Assert.Equal(3, result.Count); + Assert.Empty(result[0].Item2); + Assert.Empty(result[1].Item2); + Assert.Equal(new OtherOrganizationDoesNotAllowOtherMembership().Message, result[2].Item2); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs index c223520a04..dfb1b35be0 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs @@ -1,5 +1,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs new file mode 100644 index 0000000000..caae3a3b12 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs @@ -0,0 +1,113 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +[SutProviderCustomize] +public class BulkResendOrganizationInvitesCommandTests +{ + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_ValidatesUsersAndSendsBatchInvite( + Organization organization, + OrganizationUser validUser1, + OrganizationUser validUser2, + OrganizationUser acceptedUser, + OrganizationUser wrongOrgUser, + SutProvider sutProvider) + { + validUser1.OrganizationId = organization.Id; + validUser1.Status = OrganizationUserStatusType.Invited; + validUser2.OrganizationId = organization.Id; + validUser2.Status = OrganizationUserStatusType.Invited; + acceptedUser.OrganizationId = organization.Id; + acceptedUser.Status = OrganizationUserStatusType.Accepted; + wrongOrgUser.OrganizationId = Guid.NewGuid(); + wrongOrgUser.Status = OrganizationUserStatusType.Invited; + + var users = new List { validUser1, validUser2, acceptedUser, wrongOrgUser }; + var userIds = users.Select(u => u.Id).ToList(); + + sutProvider.GetDependency().GetManyAsync(userIds).Returns(users); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = (await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, userIds)).ToList(); + + Assert.Equal(4, result.Count); + Assert.Equal(2, result.Count(r => string.IsNullOrEmpty(r.Item2))); + Assert.Equal(2, result.Count(r => r.Item2 == "User invalid.")); + + await sutProvider.GetDependency() + .Received(1) + .SendInvitesAsync(Arg.Is(req => + req.Organization == organization && + req.Users.Length == 2 && + req.InitOrganization == false)); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_AllInvalidUsers_DoesNotSendInvites( + Organization organization, + List organizationUsers, + SutProvider sutProvider) + { + foreach (var user in organizationUsers) + { + user.OrganizationId = organization.Id; + user.Status = OrganizationUserStatusType.Confirmed; + } + + var userIds = organizationUsers.Select(u => u.Id).ToList(); + sutProvider.GetDependency().GetManyAsync(userIds).Returns(organizationUsers); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = (await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, userIds)).ToList(); + + Assert.Equal(organizationUsers.Count, result.Count); + Assert.All(result, r => Assert.Equal("User invalid.", r.Item2)); + await sutProvider.GetDependency().DidNotReceive() + .SendInvitesAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_OrganizationNotFound_ThrowsNotFoundException( + Guid organizationId, + List userIds, + List organizationUsers, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetManyAsync(userIds).Returns(organizationUsers); + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkResendInvitesAsync(organizationId, null, userIds)); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_EmptyUserList_ReturnsEmpty( + Organization organization, + SutProvider sutProvider) + { + var emptyUserIds = new List(); + sutProvider.GetDependency().GetManyAsync(emptyUserIds).Returns(new List()); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, emptyUserIds); + + Assert.Empty(result); + await sutProvider.GetDependency().DidNotReceive() + .SendInvitesAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs index 10dcff9e2a..5d82f0717d 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs @@ -13,7 +13,6 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.Commands; using Bit.Core.AdminConsole.Utilities.Errors; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -22,6 +21,7 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Time.Testing; @@ -29,6 +29,7 @@ using NSubstitute; using NSubstitute.ExceptionExtensions; using Xunit; using static Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Helpers.InviteUserOrganizationValidationRequestHelpers; +using Enterprise2019Plan = Bit.Core.Test.Billing.Mocks.Plans.Enterprise2019Plan; namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs index a5b220b94a..e26d9ce978 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs @@ -3,12 +3,12 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -50,7 +50,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -96,7 +96,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -140,7 +140,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 4 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs index be5586f8a6..482b369780 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs @@ -2,7 +2,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Organization; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs index 738ae71298..72a146205b 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs @@ -5,7 +5,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V using Bit.Core.AdminConsole.Utilities.Validation; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs index 571832d675..46ca37522f 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; using Bit.Core.AdminConsole.Utilities.Validation; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs index b16a80d7a2..3c2868d9e3 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs @@ -1,6 +1,6 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs new file mode 100644 index 0000000000..a74135794f --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs @@ -0,0 +1,215 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +[SutProviderCustomize] +public class RevokeOrganizationUserCommandTests +{ + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithValidUsers_RevokesUsersAndLogsEvents( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + orgUser1.UserId = Guid.NewGuid(); + orgUser2.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser1.Id, orgUser2.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser1, orgUser2]); + SetupValidatorMock(sutProvider, [ + ValidationResultHelpers.Valid(orgUser1), + ValidationResultHelpers.Valid(orgUser2) + ]); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Result.IsSuccess)); + + await sutProvider.GetDependency() + .Received(1) + .RevokeManyByIdAsync(Arg.Is>(ids => + ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id))); + + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>( + events => events.Count() == 2)); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(orgUser1.UserId!.Value); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(orgUser2.UserId!.Value); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithSystemUser_LogsEventsWithSystemUserType( + SutProvider sutProvider, + Guid organizationId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(null, false, EventSystemUser.SCIM); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser]); + SetupValidatorMock(sutProvider, [ValidationResultHelpers.Valid(orgUser)]); + + // Act + await sutProvider.Sut.RevokeUsersAsync(request); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>( + events => events.All(e => e.Item3 == EventSystemUser.SCIM))); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithValidationErrors_ReturnsErrorResults( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + + var actingUser = CreateActingUser(actingUserId, false, null); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser1.Id, orgUser2.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser1, orgUser2]); + SetupValidatorMock(sutProvider, [ + ValidationResultHelpers.Invalid(orgUser1, new UserAlreadyRevoked()), + ValidationResultHelpers.Valid(orgUser2) + ]); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + var result1 = results.Single(r => r.Id == orgUser1.Id); + var result2 = results.Single(r => r.Id == orgUser2.Id); + + Assert.True(result1.Result.IsError); + Assert.True(result2.Result.IsSuccess); + + // Only the valid user should be revoked + await sutProvider.GetDependency() + .Received(1) + .RevokeManyByIdAsync(Arg.Is>(ids => + ids.Count() == 1 && ids.Contains(orgUser2.Id))); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WhenPushNotificationFails_ContinuesProcessing( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser]); + SetupValidatorMock(sutProvider, [ValidationResultHelpers.Valid(orgUser)]); + + sutProvider.GetDependency() + .PushSyncOrgKeysAsync(orgUser.UserId!.Value) + .Returns(Task.FromException(new Exception("Push notification failed"))); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results[0].Result.IsSuccess); + + // Should log warning but continue + sutProvider.GetDependency>() + .Received() + .Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + private static IActingUser CreateActingUser(Guid? userId, bool isOwnerOrProvider, EventSystemUser? systemUserType) => + (userId, systemUserType) switch + { + ({ } id, _) => new StandardUser(id, isOwnerOrProvider), + (null, { } type) => new SystemUser(type) + }; + + private static void SetupRepositoryMocks( + SutProvider sutProvider, + ICollection organizationUsers) + { + sutProvider.GetDependency() + .GetManyAsync(Arg.Any>()) + .Returns(organizationUsers); + } + + private static void SetupValidatorMock( + SutProvider sutProvider, + ICollection> validationResults) + { + sutProvider.GetDependency() + .ValidateAsync(Arg.Any()) + .Returns(validationResults); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs new file mode 100644 index 0000000000..fe5802b00b --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs @@ -0,0 +1,325 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +[SutProviderCustomize] +public class RevokeOrganizationUsersValidatorTests +{ + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithValidUsers_ReturnsSuccess( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + orgUser1.UserId = Guid.NewGuid(); + orgUser2.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [orgUser1, orgUser2], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.IsValid)); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithRevokedUser_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser) + { + // Arrange + revokedUser.OrganizationId = organizationId; + revokedUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [revokedUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenRevokingSelf_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = actingUserId; + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [orgUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenNonOwnerRevokesOwner_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + ownerUser.OrganizationId = organizationId; + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenOwnerRevokesOwner_ReturnsSuccess( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + ownerUser.OrganizationId = organizationId; + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, true, null); + var request = CreateValidationRequest( + organizationId, + [ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMultipleUsers_SomeValid_ReturnsMixedResults( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser validUser, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser) + { + // Arrange + validUser.OrganizationId = revokedUser.OrganizationId = organizationId; + validUser.UserId = Guid.NewGuid(); + revokedUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [validUser, revokedUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + + var validResult = results.Single(r => r.Request.Id == validUser.Id); + var errorResult = results.Single(r => r.Request.Id == revokedUser.Id); + + Assert.True(validResult.IsValid); + Assert.True(errorResult.IsError); + Assert.IsType(errorResult.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithSystemUser_DoesNotRequireActingUserId( + SutProvider sutProvider, + Guid organizationId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(null, false, EventSystemUser.SCIM); + var request = CreateValidationRequest( + organizationId, + [orgUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenRevokingLastOwner_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser lastOwner) + { + // Arrange + lastOwner.OrganizationId = organizationId; + lastOwner.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, true, null); // Is an owner + var request = CreateValidationRequest( + organizationId, + [lastOwner], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(false); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMultipleValidationErrors_ReturnsAllErrors( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + revokedUser.OrganizationId = ownerUser.OrganizationId = organizationId; + revokedUser.UserId = Guid.NewGuid(); + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); // Not an owner + var request = CreateValidationRequest( + organizationId, + [revokedUser, ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.IsError)); + + Assert.Contains(results, r => r.AsError is UserAlreadyRevoked); + Assert.Contains(results, r => r.AsError is OnlyOwnersCanRevokeOwners); + } + + private static IActingUser CreateActingUser(Guid? userId, bool isOwnerOrProvider, EventSystemUser? systemUserType) => + (userId, systemUserType) switch + { + ({ } id, _) => new StandardUser(id, isOwnerOrProvider), + (null, { } type) => new SystemUser(type) + }; + + + private static RevokeOrganizationUsersValidationRequest CreateValidationRequest( + Guid organizationId, + ICollection organizationUsers, + IActingUser actingUser) + { + return new RevokeOrganizationUsersValidationRequest( + organizationId, + organizationUsers.Select(u => u.Id).ToList(), + actingUser, + organizationUsers + ); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs index af6b5a17f7..f1c4797de8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs @@ -1,9 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; using Bit.Core.Repositories; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs index feb5ef2a40..c1fea1455e 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -28,7 +28,7 @@ public class CloudICloudOrganizationSignUpCommandTests { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); signup.AdditionalSeats = 0; signup.PaymentMethodType = PaymentMethodType.Card; @@ -37,7 +37,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.IsFromSecretsManagerTrial = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var result = await sutProvider.Sut.SignUpOrganizationAsync(signup); @@ -77,7 +77,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.UseSecretsManager = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); // Extract orgUserId when created Guid? orgUserId = null; @@ -112,7 +112,7 @@ public class CloudICloudOrganizationSignUpCommandTests { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); signup.UseSecretsManager = true; signup.AdditionalSeats = 15; @@ -123,7 +123,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.IsFromSecretsManagerTrial = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var result = await sutProvider.Sut.SignUpOrganizationAsync(signup); @@ -164,7 +164,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.PremiumAccessAddon = false; signup.IsFromProvider = true; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpOrganizationAsync(signup)); Assert.Contains("Organizations with a Managed Service Provider do not support Secrets Manager.", exception.Message); @@ -184,7 +184,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalStorageGb = 0; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -204,7 +204,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalServiceAccounts = 10; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -224,7 +224,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalServiceAccounts = -10; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -244,7 +244,7 @@ public class CloudICloudOrganizationSignUpCommandTests Owner = new User { Id = Guid.NewGuid() } }; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); sutProvider.GetDependency() .GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id) diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs index 881f134b4c..5385b4cdea 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Models.Data; using Bit.Core.Models.StaticStore; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -36,7 +36,7 @@ public class ProviderClientOrganizationSignUpCommandTests signup.AdditionalSeats = 15; signup.CollectionName = collectionName; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); @@ -112,7 +112,7 @@ public class ProviderClientOrganizationSignUpCommandTests signup.Plan = PlanType.TeamsMonthly; signup.AdditionalSeats = -5; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); @@ -132,7 +132,7 @@ public class ProviderClientOrganizationSignUpCommandTests { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs index 55e5698ad4..69f69b1d02 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -172,7 +173,7 @@ public class ResellerClientOrganizationSignUpCommandTests private static async Task AssertCleanupIsPerformed(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .CancelAndRecoverChargesAsync(Arg.Any()); await sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs new file mode 100644 index 0000000000..d547d80aed --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs @@ -0,0 +1,414 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; +using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations; + +[SutProviderCustomize] +public class OrganizationUpdateCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_WhenValidOrganization_UpdatesOrganization( + Guid organizationId, + string name, + string billingEmail, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.GatewayCustomerId = null; // No Stripe customer, but billing update is still called + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = name, + BillingEmail = billingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(name, result.Name); + Assert.Equal(billingEmail.ToLowerInvariant().Trim(), result.BillingEmail); + + await organizationRepository + .Received(1) + .GetByIdAsync(Arg.Is(id => id == organizationId)); + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenOrganizationNotFound_ThrowsNotFoundException( + Guid organizationId, + string name, + string billingEmail, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository + .GetByIdAsync(organizationId) + .Returns((Organization)null); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = name, + BillingEmail = billingEmail + }; + + // Act/Assert + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(request)); + } + + [Theory] + [BitAutoData("")] + [BitAutoData((string)null)] + public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_CallsBillingUpdateButHandledGracefully( + string gatewayCustomerId, + Guid organizationId, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.Name = "Old Name"; + organization.GatewayCustomerId = gatewayCustomerId; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = "New Name", + BillingEmail = organization.BillingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal("New Name", result.Name); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenKeysProvided_AndNotAlreadySet_SetsKeys( + Guid organizationId, + string publicKey, + string encryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.PublicKey = null; + organization.PrivateKey = null; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = organization.Name, + BillingEmail = organization.BillingEmail, + PublicKey = publicKey, + EncryptedPrivateKey = encryptedPrivateKey + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(publicKey, result.PublicKey); + Assert.Equal(encryptedPrivateKey, result.PrivateKey); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenKeysProvided_AndAlreadySet_DoesNotOverwriteKeys( + Guid organizationId, + string newPublicKey, + string newEncryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + + organization.Id = organizationId; + var existingPublicKey = organization.PublicKey; + var existingPrivateKey = organization.PrivateKey; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = organization.Name, + BillingEmail = organization.BillingEmail, + PublicKey = newPublicKey, + EncryptedPrivateKey = newEncryptedPrivateKey + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(existingPublicKey, result.PublicKey); + Assert.Equal(existingPrivateKey, result.PrivateKey); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UpdatingNameOnly_UpdatesNameAndNotBillingEmail( + Guid organizationId, + string newName, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.Name = "Old Name"; + var originalBillingEmail = organization.BillingEmail; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = newName, + BillingEmail = null + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(newName, result.Name); + Assert.Equal(originalBillingEmail, result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UpdatingBillingEmailOnly_UpdatesBillingEmailAndNotName( + Guid organizationId, + string newBillingEmail, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.BillingEmail = "old@example.com"; + var originalName = organization.Name; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = null, + BillingEmail = newBillingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(originalName, result.Name); + Assert.Equal(newBillingEmail.ToLowerInvariant().Trim(), result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenNoChanges_PreservesBothFields( + Guid organizationId, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + var originalName = organization.Name; + var originalBillingEmail = organization.BillingEmail; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = null, + BillingEmail = null + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(originalName, result.Name); + Assert.Equal(originalBillingEmail, result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .DidNotReceiveWithAnyArgs() + .UpdateOrganizationNameAndEmail(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_SelfHosted_OnlyUpdatesKeysNotOrganizationDetails( + Guid organizationId, + string newName, + string newBillingEmail, + string publicKey, + string encryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationBillingService = sutProvider.GetDependency(); + var globalSettings = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + + globalSettings.SelfHosted.Returns(true); + + organization.Id = organizationId; + organization.Name = "Original Name"; + organization.BillingEmail = "original@example.com"; + organization.PublicKey = null; + organization.PrivateKey = null; + + organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = newName, // Should be ignored + BillingEmail = newBillingEmail, // Should be ignored + PublicKey = publicKey, + EncryptedPrivateKey = encryptedPrivateKey + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.Equal("Original Name", result.Name); // Not changed + Assert.Equal("original@example.com", result.BillingEmail); // Not changed + Assert.Equal(publicKey, result.PublicKey); // Changed + Assert.Equal(encryptedPrivateKey, result.PrivateKey); // Changed + + await organizationBillingService + .DidNotReceiveWithAnyArgs() + .UpdateOrganizationNameAndEmail(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs index 37a5627919..47872cc6ab 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs @@ -2,10 +2,10 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Services; using Bit.Core.Models.StaticStore; using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -28,7 +28,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() .AdjustSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); @@ -53,7 +53,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), @@ -81,7 +81,7 @@ public class UpdateOrganizationSubscriptionCommandTests OrganizationSubscriptionUpdate[] subscriptionsToUpdate = [new() { Organization = organization, Plan = new Enterprise2023Plan(true) }]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), Arg.Is(x => x.Type == organization.PlanType), @@ -115,7 +115,7 @@ public class UpdateOrganizationSubscriptionCommandTests new() { Organization = failedOrganization, Plan = new Enterprise2023Plan(true) } ]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == failedOrganization.Id), Arg.Is(x => x.Type == failedOrganization.PlanType), @@ -124,7 +124,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == successfulOrganization.Id), diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs new file mode 100644 index 0000000000..f2e6adbfa9 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs @@ -0,0 +1,306 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +[SutProviderCustomize] +public class AutomaticUserConfirmationPolicyEnforcementValidatorTests +{ + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledAndUserIsProviderMember_ReturnsProviderUsersCannotJoinError( + SutProvider sutProvider, + OrganizationUser organizationUser, + ProviderUser providerUser, + User user) + { + // Arrange + organizationUser.UserId = providerUser.UserId = user.Id; + + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledOnOtherOrganization_ReturnsOtherOrganizationDoesNotAllowOtherMembershipError( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrganizationUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrganizationUser.UserId = user.Id; + + var otherOrgId = Guid.NewGuid(); + var policyDetails = new PolicyDetails + { + OrganizationId = otherOrgId, // Different from organizationUser.OrganizationId + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrganizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledUserIsAMemberOfAnotherOrgReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledUserIsAMemberOfAnotherOrg_ReturnsCannotBeMemberOfAnotherOrgError( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledAndChecksConditionsInCorrectOrder_ReturnsFirstFailure( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + ProviderUser providerUser, + User user) + { + // Arrange + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation, + OrganizationUserId = organizationUser.Id + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyIsEnabledNoOtherOrganizationsAndNotAProvider_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([ + new PolicyDetails + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation, + } + ])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledForCurrentAndOtherOrg_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + otherOrgUser.UserId = organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledForCurrentAndOtherOrgAndIsProvider_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + ProviderUser providerUser, + User user) + { + // Arrange + providerUser.UserId = otherOrgUser.UserId = organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs index 4781127a3d..3c9fd9a9e9 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs @@ -21,52 +21,23 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidat public class AutomaticUserConfirmationPolicyEventHandlerTests { [Theory, BitAutoData] - public async Task ValidateAsync_EnablingPolicy_SingleOrgNotEnabled_ReturnsError( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + public void RequiredPolicies_IncludesSingleOrg( SutProvider sutProvider) { - // Arrange - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns((Policy?)null); - // Act - var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + var requiredPolicies = sutProvider.Sut.RequiredPolicies; // Assert - Assert.Contains("Single organization policy must be enabled", result, StringComparison.OrdinalIgnoreCase); - } - - [Theory, BitAutoData] - public async Task ValidateAsync_EnablingPolicy_SingleOrgPolicyDisabled_ReturnsError( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg, false)] Policy singleOrgPolicy, - SutProvider sutProvider) - { - // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - - // Act - var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); - - // Assert - Assert.Contains("Single organization policy must be enabled", result, StringComparison.OrdinalIgnoreCase); + Assert.Contains(PolicyType.SingleOrg, requiredPolicies); } [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_UsersNotCompliantWithSingleOrg_ReturnsError( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, Guid nonCompliantUserId, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var orgUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -85,10 +56,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Status = OrganizationUserStatusType.Confirmed }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([orgUser]); @@ -107,13 +74,10 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_UserWithInvitedStatusInOtherOrg_ValidationPasses( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, Guid userId, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var orgUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -121,7 +85,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Confirmed, UserId = userId, - Email = "test@email.com" }; var otherOrgUser = new OrganizationUser @@ -133,10 +96,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Email = orgUser.Email }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([orgUser]); @@ -146,7 +105,7 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests .Returns([otherOrgUser]); sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([]); // Act @@ -159,30 +118,37 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_ProviderUsersExist_ReturnsError( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, + Guid userId, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; + var orgUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = userId + }; var providerUser = new ProviderUser { Id = Guid.NewGuid(), ProviderId = Guid.NewGuid(), - UserId = Guid.NewGuid(), + UserId = userId, Status = ProviderUserStatusType.Confirmed }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([]); sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([providerUser]); // Act @@ -196,26 +162,18 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_AllValidationsPassed_ReturnsEmptyString( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var orgUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), OrganizationId = policyUpdate.OrganizationId, Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Confirmed, - UserId = Guid.NewGuid(), - Email = "user@example.com" + UserId = Guid.NewGuid() }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([orgUser]); @@ -225,7 +183,7 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests .Returns([]); sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([]); // Act @@ -249,9 +207,10 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests // Assert Assert.True(string.IsNullOrEmpty(result)); - await sutProvider.GetDependency() + + await sutProvider.GetDependency() .DidNotReceive() - .GetByOrganizationIdTypeAsync(Arg.Any(), Arg.Any()); + .GetManyDetailsByOrganizationAsync(Arg.Any()); } [Theory, BitAutoData] @@ -268,21 +227,18 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests // Assert Assert.True(string.IsNullOrEmpty(result)); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .GetByOrganizationIdTypeAsync(Arg.Any(), Arg.Any()); + .GetManyDetailsByOrganizationAsync(Arg.Any()); } [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_IncludesOwnersAndAdmins_InComplianceCheck( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, Guid nonCompliantOwnerId, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var ownerUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -290,7 +246,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Type = OrganizationUserType.Owner, Status = OrganizationUserStatusType.Confirmed, UserId = nonCompliantOwnerId, - Email = "owner@example.com" }; var otherOrgUser = new OrganizationUser @@ -301,10 +256,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Status = OrganizationUserStatusType.Confirmed }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([ownerUser]); @@ -323,12 +274,9 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_InvitedUsersExcluded_FromComplianceCheck( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var invitedUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -339,16 +287,12 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Email = "invited@example.com" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([invitedUser]); sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([]); // Act @@ -359,14 +303,11 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests } [Theory, BitAutoData] - public async Task ValidateAsync_EnablingPolicy_RevokedUsersExcluded_FromComplianceCheck( + public async Task ValidateAsync_EnablingPolicy_RevokedUsersIncluded_InComplianceCheck( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var revokedUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -374,38 +315,44 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Revoked, UserId = Guid.NewGuid(), - Email = "revoked@example.com" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); + var additionalOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Revoked, + UserId = revokedUser.UserId, + }; - sutProvider.GetDependency() + var orgUserRepository = sutProvider.GetDependency(); + + orgUserRepository .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([revokedUser]); + orgUserRepository.GetManyByManyUsersAsync(Arg.Any>()) + .Returns([additionalOrgUser]); + sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) + .GetManyByManyUsersAsync(Arg.Any>()) .Returns([]); // Act var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); // Assert - Assert.True(string.IsNullOrEmpty(result)); + Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); } [Theory, BitAutoData] public async Task ValidateAsync_EnablingPolicy_AcceptedUsersIncluded_InComplianceCheck( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, Guid nonCompliantUserId, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var acceptedUser = new OrganizationUserUserDetails { Id = Guid.NewGuid(), @@ -413,7 +360,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Accepted, UserId = nonCompliantUserId, - Email = "accepted@example.com" }; var otherOrgUser = new OrganizationUser @@ -424,10 +370,6 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Status = OrganizationUserStatusType.Confirmed }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([acceptedUser]); @@ -443,186 +385,22 @@ public class AutomaticUserConfirmationPolicyEventHandlerTests Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); } - [Theory, BitAutoData] - public async Task ValidateAsync_EnablingPolicy_EmptyOrganization_ReturnsEmptyString( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, - SutProvider sutProvider) - { - // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) - .Returns([]); - - sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) - .Returns([]); - - // Act - var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); - - // Assert - Assert.True(string.IsNullOrEmpty(result)); - } - [Theory, BitAutoData] public async Task ValidateAsync_WithSavePolicyModel_CallsValidateWithPolicyUpdate( [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, SutProvider sutProvider) { // Arrange - singleOrgPolicy.OrganizationId = policyUpdate.OrganizationId; - var savePolicyModel = new SavePolicyModel(policyUpdate); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, PolicyType.SingleOrg) - .Returns(singleOrgPolicy); - sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) .Returns([]); - sutProvider.GetDependency() - .GetManyByOrganizationAsync(policyUpdate.OrganizationId) - .Returns([]); - // Act var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); // Assert Assert.True(string.IsNullOrEmpty(result)); } - - [Theory, BitAutoData] - public async Task OnSaveSideEffectsAsync_EnablingPolicy_SetsUseAutomaticUserConfirmationToTrue( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - Organization organization, - SutProvider sutProvider) - { - // Arrange - organization.Id = policyUpdate.OrganizationId; - organization.UseAutomaticUserConfirmation = false; - - sutProvider.GetDependency() - .GetByIdAsync(policyUpdate.OrganizationId) - .Returns(organization); - - // Act - await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(Arg.Is(o => - o.Id == organization.Id && - o.UseAutomaticUserConfirmation == true && - o.RevisionDate > DateTime.MinValue)); - } - - [Theory, BitAutoData] - public async Task OnSaveSideEffectsAsync_DisablingPolicy_SetsUseAutomaticUserConfirmationToFalse( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation, false)] PolicyUpdate policyUpdate, - Organization organization, - SutProvider sutProvider) - { - // Arrange - organization.Id = policyUpdate.OrganizationId; - organization.UseAutomaticUserConfirmation = true; - - sutProvider.GetDependency() - .GetByIdAsync(policyUpdate.OrganizationId) - .Returns(organization); - - // Act - await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(Arg.Is(o => - o.Id == organization.Id && - o.UseAutomaticUserConfirmation == false && - o.RevisionDate > DateTime.MinValue)); - } - - [Theory, BitAutoData] - public async Task OnSaveSideEffectsAsync_OrganizationNotFound_DoesNotThrowException( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - SutProvider sutProvider) - { - // Arrange - sutProvider.GetDependency() - .GetByIdAsync(policyUpdate.OrganizationId) - .Returns((Organization?)null); - - // Act - await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null); - - // Assert - await sutProvider.GetDependency() - .DidNotReceive() - .UpsertAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task ExecutePreUpsertSideEffectAsync_CallsOnSaveSideEffectsAsync( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy, - Organization organization, - SutProvider sutProvider) - { - // Arrange - organization.Id = policyUpdate.OrganizationId; - currentPolicy.OrganizationId = policyUpdate.OrganizationId; - - var savePolicyModel = new SavePolicyModel(policyUpdate); - - sutProvider.GetDependency() - .GetByIdAsync(policyUpdate.OrganizationId) - .Returns(organization); - - // Act - await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, currentPolicy); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(Arg.Is(o => - o.Id == organization.Id && - o.UseAutomaticUserConfirmation == policyUpdate.Enabled)); - } - - [Theory, BitAutoData] - public async Task OnSaveSideEffectsAsync_UpdatesRevisionDate( - [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, - Organization organization, - SutProvider sutProvider) - { - // Arrange - organization.Id = policyUpdate.OrganizationId; - var originalRevisionDate = DateTime.UtcNow.AddDays(-1); - organization.RevisionDate = originalRevisionDate; - - sutProvider.GetDependency() - .GetByIdAsync(policyUpdate.OrganizationId) - .Returns(organization); - - // Act - await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, null); - - // Assert - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(Arg.Is(o => - o.Id == organization.Id && - o.RevisionDate > originalRevisionDate)); - } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs new file mode 100644 index 0000000000..e317a5886e --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs @@ -0,0 +1,189 @@ +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +[SutProviderCustomize] +public class BlockClaimedDomainAccountCreationPolicyValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_NoVerifiedDomains_ValidationError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Equal("You must claim at least one domain to turn on this policy", result); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_HasVerifiedDomains_Success( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_NoValidation( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, false)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_EnablingPolicy_NoVerifiedDomains_ValidationError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(false); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.Equal("You must claim at least one domain to turn on this policy", result); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_EnablingPolicy_HasVerifiedDomains_Success( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(true); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_NoValidation( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, false)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_FeatureFlagDisabled_ReturnsError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Equal("This feature is not enabled", result); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Fact] + public void Type_ReturnsBlockClaimedDomainAccountCreation() + { + // Arrange + var validator = new BlockClaimedDomainAccountCreationPolicyValidator(null, null); + + // Act & Assert + Assert.Equal(PolicyType.BlockClaimedDomainAccountCreation, validator.Type); + } + + [Fact] + public void RequiredPolicies_ReturnsEmpty() + { + // Arrange + var validator = new BlockClaimedDomainAccountCreationPolicyValidator(null, null); + + // Act + var requiredPolicies = validator.RequiredPolicies.ToList(); + + // Assert + Assert.Empty(requiredPolicies); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs index b1e3faf257..275466a9bd 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs @@ -6,8 +6,11 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; using Bit.Core.Models.Data.Organizations; +using Bit.Core.Platform.Push; using Bit.Core.Services; using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -95,7 +98,8 @@ public class SavePolicyCommandTests Substitute.For(), [new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()], Substitute.For(), - Substitute.For())); + Substitute.For(), + Substitute.For())); Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message); } @@ -360,6 +364,103 @@ public class SavePolicyCommandTests .ExecuteSideEffectsAsync(default!, default!, default!); } + [Theory, BitAutoData] + public async Task VNextSaveAsync_SendsPushNotification( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + // Arrange + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + var savePolicyModel = new SavePolicyModel(policyUpdate); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + // Act + var result = await sutProvider.Sut.VNextSaveAsync(savePolicyModel); + + // Assert + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_SendsPushNotification([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_SendsPushNotificationWithUpdatedPolicy( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + /// /// Returns a new SutProvider with the PolicyValidators registered in the Sut. /// diff --git a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs index 1d94d58aa5..235d597b12 100644 --- a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs +++ b/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs @@ -1,18 +1,23 @@ -using System.Text.Json; +#nullable enable + +using System.Text.Json; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Entities; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; +using ZiggyCreatures.Caching.Fusion; namespace Bit.Core.Test.Services; @@ -20,9 +25,10 @@ namespace Bit.Core.Test.Services; public class EventIntegrationHandlerTests { private const string _templateBase = "Date: #Date#, Type: #Type#, UserId: #UserId#"; + private const string _templateWithGroup = "Group: #GroupName#"; private const string _templateWithOrganization = "Org: #OrganizationName#"; - private const string _templateWithUser = "#UserName#, #UserEmail#"; - private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#"; + private const string _templateWithUser = "#UserName#, #UserEmail#, #UserType#"; + private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#, #ActingUserType#"; private static readonly Guid _organizationId = Guid.NewGuid(); private static readonly Uri _uri = new Uri("https://localhost"); private static readonly Uri _uri2 = new Uri("https://example.com"); @@ -33,19 +39,23 @@ public class EventIntegrationHandlerTests private SutProvider> GetSutProvider( List configurations) { - var configurationCache = Substitute.For(); - configurationCache.GetConfigurationDetails(Arg.Any(), - IntegrationType.Webhook, Arg.Any()).Returns(configurations); + var cache = Substitute.For(); + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>>>(), + options: Arg.Any(), + tags: Arg.Any>() + ).Returns(configurations); return new SutProvider>() - .SetDependency(configurationCache) + .SetDependency(cache) .SetDependency(_eventIntegrationPublisher) .SetDependency(IntegrationType.Webhook) .SetDependency(_logger) .Create(); } - private static IntegrationMessage expectedMessage(string template) + private static IntegrationMessage ExpectedMessage(string template) { return new IntegrationMessage() { @@ -105,16 +115,363 @@ public class EventIntegrationHandlerTests config.Configuration = null; config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); config.Template = _templateBase; - config.Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() { }); + config.Filters = JsonSerializer.Serialize(new IntegrationFilterGroup()); return [config]; } + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserIdPresent_UsesCache(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(actingUser); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Equal(actingUser, context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserOrganizationIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + eventMessage.ActingUserId ??= Guid.NewGuid(); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserFactory_CallsOrganizationUserRepository(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId ??= Guid.NewGuid(); + organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.ActingUserId.Value).Returns(actingUser); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(actingUser); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationUserRepository.Received(1).GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.ActingUserId.Value); + Assert.Equal(actingUser, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupIdPresent_UsesCache(EventMessage eventMessage, Group group) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + + eventMessage.GroupId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(group); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Equal(group, context.Group); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + eventMessage.GroupId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.Group); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupFactory_CallsGroupRepository(EventMessage eventMessage, Group group) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + var groupRepository = sutProvider.GetDependency(); + + eventMessage.GroupId ??= Guid.NewGuid(); + groupRepository.GetByIdAsync(eventMessage.GroupId.Value).Returns(group); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(group); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await groupRepository.Received(1).GetByIdAsync(eventMessage.GroupId.Value); + Assert.Equal(group, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationIdPresent_UsesCache(EventMessage eventMessage, Organization organization) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(organization); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Equal(organization, context.Organization); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.Organization); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationFactory_CallsOrganizationRepository(EventMessage eventMessage, Organization organization) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + organizationRepository.GetByIdAsync(eventMessage.OrganizationId.Value).Returns(organization); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(organization); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationRepository.Received(1).GetByIdAsync(eventMessage.OrganizationId.Value); + Assert.Equal(organization, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserIdPresent_UsesCache(EventMessage eventMessage, OrganizationUserUserDetails userDetails) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(userDetails); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Equal(userDetails, context.User); + } + + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + eventMessage.UserId ??= Guid.NewGuid(); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Null(context.User); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationUserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Null(context.User); + } + + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserFactory_CallsOrganizationUserRepository(EventMessage eventMessage, OrganizationUserUserDetails userDetails) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.UserId.Value).Returns(userDetails); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(userDetails); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationUserRepository.Received(1).GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.UserId.Value); + Assert.Equal(userDetails, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_NoSpecialTokens_DoesNotCallAnyCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.ActingUserId ??= Guid.NewGuid(); + eventMessage.GroupId ??= Guid.NewGuid(); + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateBase); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + } [Theory, BitAutoData] public async Task HandleEventAsync_BaseTemplateNoConfigurations_DoesNothing(EventMessage eventMessage) { var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + cache.GetOrSetAsync>( + Arg.Any(), + Arg.Any>>>(), + Arg.Any() + ).Returns(NoConfigurations()); await sutProvider.Sut.HandleEventAsync(eventMessage); Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); @@ -133,31 +490,32 @@ public class EventIntegrationHandlerTests [Theory, BitAutoData] public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage) { - var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); await sutProvider.Sut.HandleEventAsync(eventMessage); - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetDetailsByOrganizationIdUserIdAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] public async Task HandleEventAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages(EventMessage eventMessage) { - var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); await sutProvider.Sut.HandleEventAsync(eventMessage); - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( @@ -167,77 +525,15 @@ public class EventIntegrationHandlerTests await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_ActingUserTemplate_LoadsUserFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); - var user = Substitute.For(); - user.Email = "test@example.com"; - user.Name = "Test"; - eventMessage.OrganizationId = _organizationId; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"{user.Name}, {user.Email}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.ActingUserId ?? Guid.Empty); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_OrganizationTemplate_LoadsOrganizationFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); - var organization = Substitute.For(); - organization.Name = "Test"; - eventMessage.OrganizationId = _organizationId; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(organization); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"Org: {organization.Name}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.OrganizationId ?? Guid.Empty); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_UserTemplate_LoadsUserFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); - var user = Substitute.For(); - user.Email = "test@example.com"; - user.Name = "Test"; - eventMessage.OrganizationId = _organizationId; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"{user.Name}, {user.Email}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.UserId ?? Guid.Empty); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetDetailsByOrganizationIdUserIdAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] public async Task HandleEventAsync_FilterReturnsFalse_DoesNothing(EventMessage eventMessage) { + eventMessage.OrganizationId = _organizationId; var sutProvider = GetSutProvider(ValidFilterConfiguration()); sutProvider.GetDependency().EvaluateFilterGroup( Arg.Any(), Arg.Any()).Returns(false); @@ -249,14 +545,14 @@ public class EventIntegrationHandlerTests [Theory, BitAutoData] public async Task HandleEventAsync_FilterReturnsTrue_PublishesIntegrationMessage(EventMessage eventMessage) { + eventMessage.OrganizationId = _organizationId; var sutProvider = GetSutProvider(ValidFilterConfiguration()); sutProvider.GetDependency().EvaluateFilterGroup( Arg.Any(), Arg.Any()).Returns(true); - eventMessage.OrganizationId = _organizationId; await sutProvider.Sut.HandleEventAsync(eventMessage); - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); @@ -268,6 +564,7 @@ public class EventIntegrationHandlerTests [Theory, BitAutoData] public async Task HandleEventAsync_InvalidFilter_LogsErrorDoesNothing(EventMessage eventMessage) { + eventMessage.OrganizationId = _organizationId; var sutProvider = GetSutProvider(InvalidFilterConfiguration()); await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -277,12 +574,13 @@ public class EventIntegrationHandlerTests Arg.Any(), Arg.Any(), Arg.Any(), - Arg.Any>()); + Arg.Any>()); } [Theory, BitAutoData] public async Task HandleManyEventsAsync_BaseTemplateNoConfigurations_DoesNothing(List eventMessages) { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); var sutProvider = GetSutProvider(NoConfigurations()); await sutProvider.Sut.HandleManyEventsAsync(eventMessages); @@ -292,13 +590,14 @@ public class EventIntegrationHandlerTests [Theory, BitAutoData] public async Task HandleManyEventsAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessages(List eventMessages) { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); await sutProvider.Sut.HandleManyEventsAsync(eventMessages); foreach (var eventMessage in eventMessages) { - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( + var expectedMessage = ExpectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( @@ -310,13 +609,14 @@ public class EventIntegrationHandlerTests public async Task HandleManyEventsAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages( List eventMessages) { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); await sutProvider.Sut.HandleManyEventsAsync(eventMessages); foreach (var eventMessage in eventMessages) { - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( + var expectedMessage = ExpectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual( @@ -327,4 +627,84 @@ public class EventIntegrationHandlerTests expectedMessage, new[] { "MessageId", "OrganizationId" }))); } } + + [Theory, BitAutoData] + public async Task HandleEventAsync_CapturedFactories_CallConfigurationRepository(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + var configurationRepository = sutProvider.GetDependency(); + + var configs = OneConfiguration(_templateBase); + + configurationRepository.GetManyByEventTypeOrganizationIdIntegrationType(eventType: eventMessage.Type, organizationId: _organizationId, integrationType: IntegrationType.Webhook).Returns(configs); + + // Capture the factory function - there will be 1 call that returns both specific and wildcard matches + Func>, CancellationToken, Task>>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do>, CancellationToken, Task>>>(f + => capturedFactory = f), + options: Arg.Any(), + tags: Arg.Any>() + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + // Verify factory was captured + Assert.NotNull(capturedFactory); + + // Execute the captured factory to trigger repository call + await capturedFactory(null!, CancellationToken.None); + + await configurationRepository.Received(1).GetManyByEventTypeOrganizationIdIntegrationType(eventType: eventMessage.Type, organizationId: _organizationId, integrationType: IntegrationType.Webhook); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_ConfigurationCacheOptions_SetsDurationToConstant(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + + FusionCacheEntryOptions? capturedOption = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>, CancellationToken, Task>>>(), + options: Arg.Do(opt => capturedOption = opt), + tags: Arg.Any?>() + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + Assert.NotNull(capturedOption); + Assert.Equal(EventIntegrationsCacheConstants.DurationForOrganizationIntegrationConfigurationDetails, + capturedOption.Duration); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_ConfigurationCache_AddsOrganizationIntegrationTag(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + + IEnumerable? capturedTags = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>, CancellationToken, Task>>>(), + options: Arg.Any(), + tags: Arg.Do>(t => capturedTags = t) + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + var expectedTag = EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + _organizationId, + IntegrationType.Webhook + ); + Assert.NotNull(capturedTags); + Assert.Contains(expectedTag, capturedTags); + } } diff --git a/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs b/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs deleted file mode 100644 index 4e87d13caf..0000000000 --- a/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs +++ /dev/null @@ -1,173 +0,0 @@ -#nullable enable - -using System.Text.Json; -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.Extensions.Logging; -using NSubstitute; -using NSubstitute.ExceptionExtensions; -using Xunit; - -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class IntegrationConfigurationDetailsCacheServiceTests -{ - private SutProvider GetSutProvider( - List configurations) - { - var configurationRepository = Substitute.For(); - configurationRepository.GetAllConfigurationDetailsAsync().Returns(configurations); - - return new SutProvider() - .SetDependency(configurationRepository) - .Create(); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_SpecificKeyExists_ReturnsExpectedList(OrganizationIntegrationConfigurationDetails config) - { - config.EventType = EventType.Cipher_Created; - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - EventType.Cipher_Created); - Assert.Single(result); - Assert.Same(config, result[0]); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_AllEventsKeyExists_ReturnsExpectedList(OrganizationIntegrationConfigurationDetails config) - { - config.EventType = null; - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - EventType.Cipher_Created); - Assert.Single(result); - Assert.Same(config, result[0]); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_BothSpecificAndAllEventsKeyExists_ReturnsExpectedList( - OrganizationIntegrationConfigurationDetails specificConfig, - OrganizationIntegrationConfigurationDetails allKeysConfig - ) - { - specificConfig.EventType = EventType.Cipher_Created; - allKeysConfig.EventType = null; - allKeysConfig.OrganizationId = specificConfig.OrganizationId; - allKeysConfig.IntegrationType = specificConfig.IntegrationType; - - var sutProvider = GetSutProvider([specificConfig, allKeysConfig]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - specificConfig.OrganizationId, - specificConfig.IntegrationType, - EventType.Cipher_Created); - Assert.Equal(2, result.Count); - Assert.Contains(result, r => r.Template == specificConfig.Template); - Assert.Contains(result, r => r.Template == allKeysConfig.Template); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_KeyMissing_ReturnsEmptyList(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - Guid.NewGuid(), - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Empty(result); - } - - - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_ReturnsCachedValue_EvenIfRepositoryChanges(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - - var newConfig = JsonSerializer.Deserialize(JsonSerializer.Serialize(config)); - Assert.NotNull(newConfig); - newConfig.Template = "Changed"; - sutProvider.GetDependency().GetAllConfigurationDetailsAsync() - .Returns([newConfig]); - - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Single(result); - Assert.NotEqual("Changed", result[0].Template); // should not yet pick up change from repository - - await sutProvider.Sut.RefreshAsync(); // Pick up changes - - result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Single(result); - Assert.Equal("Changed", result[0].Template); // Should have the new value - } - - [Theory, BitAutoData] - public async Task RefreshAsync_GroupsByCompositeKey(OrganizationIntegrationConfigurationDetails config1) - { - var config2 = JsonSerializer.Deserialize( - JsonSerializer.Serialize(config1))!; - config2.Template = "Another"; - - var sutProvider = GetSutProvider([config1, config2]); - await sutProvider.Sut.RefreshAsync(); - - var results = sutProvider.Sut.GetConfigurationDetails( - config1.OrganizationId, - config1.IntegrationType, - config1.EventType ?? EventType.Cipher_Created); - - Assert.Equal(2, results.Count); - Assert.Contains(results, r => r.Template == config1.Template); - Assert.Contains(results, r => r.Template == config2.Template); - } - - [Theory, BitAutoData] - public async Task RefreshAsync_LogsInformationOnSuccess(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - - sutProvider.GetDependency>().Received().Log( - LogLevel.Information, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("Refreshed successfully")), - null, - Arg.Any>()); - } - - [Fact] - public async Task RefreshAsync_OnException_LogsError() - { - var sutProvider = GetSutProvider([]); - sutProvider.GetDependency().GetAllConfigurationDetailsAsync() - .Throws(new Exception("Database failure")); - await sutProvider.Sut.RefreshAsync(); - - sutProvider.GetDependency>().Received(1).Log( - LogLevel.Error, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("Refresh failed")), - Arg.Any(), - Arg.Any>()); - } -} diff --git a/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs new file mode 100644 index 0000000000..1154ad8025 --- /dev/null +++ b/test/Core.Test/AdminConsole/Services/OrganizationIntegrationConfigurationValidatorTests.cs @@ -0,0 +1,244 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Enums; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.Services; + +public class OrganizationIntegrationConfigurationValidatorTests +{ + private readonly OrganizationIntegrationConfigurationValidator _sut = new(); + + [Fact] + public void ValidateConfiguration_CloudBillingSyncIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.CloudBillingSync, configuration)); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyTemplate_ReturnsFalse(string? template) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config2)); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyNonNullConfiguration_ReturnsFalse(string? config) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_NullConfiguration_ReturnsTrue() + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonConfiguration_ReturnsFalse() + { + var config = new OrganizationIntegrationConfiguration + { + Configuration = "{not valid json}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonFilters_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = "template", + Filters = "{Not valid json}" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ScimIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Scim, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidNoAuthWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://localhost"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://example.com"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_UnknownIntegrationType_ReturnsFalse() + { + var unknownType = (IntegrationType)999; + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(unknownType, configuration)); + } +} diff --git a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs index 33f2e78799..43a33cda31 100644 --- a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs @@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -21,8 +22,8 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tokens; -using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Fakes; @@ -618,7 +619,7 @@ public class OrganizationServiceTests SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository); SetupOrgUserRepositoryCreateAsyncMock(organizationUserRepository); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); await sutProvider.Sut.InviteUsersAsync(organization.Id, savingUser.Id, systemUser: null, invites); @@ -666,7 +667,7 @@ public class OrganizationServiceTests .SendInvitesAsync(Arg.Any()).ThrowsAsync(); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); await Assert.ThrowsAsync(async () => await sutProvider.Sut.InviteUsersAsync(organization.Id, savingUser.Id, systemUser: null, invites)); @@ -732,7 +733,7 @@ public class OrganizationServiceTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, seatAdjustment, maxAutoscaleSeats)); @@ -757,7 +758,7 @@ public class OrganizationServiceTests organization.SmSeats = 100; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -837,7 +838,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.EnterpriseMonthly)] public void ValidateSecretsManagerPlan_ThrowsException_WhenNoSecretsManagerSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -854,7 +855,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.Free)] public void ValidateSecretsManagerPlan_ThrowsException_WhenSubtractingSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -871,7 +872,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -890,7 +891,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.EnterpriseMonthly)] public void ValidateSecretsManagerPlan_ThrowsException_WhenMoreSeatsThanPasswordManagerSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -912,7 +913,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -930,7 +931,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -952,7 +953,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -1142,7 +1143,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .Received(1) - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(id => id == organization.GatewayCustomerId), Arg.Is(options => options.Email == requestOptionsReturned.Email && options.Description == requestOptionsReturned.Description @@ -1182,7 +1183,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); await organizationRepository .Received(1) .ReplaceAsync(Arg.Is(org => org == organization)); diff --git a/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs b/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs index d9df9486b6..aee4af346c 100644 --- a/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs +++ b/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs @@ -83,6 +83,7 @@ public class IntegrationTemplateProcessorTests [Theory] [InlineData("User name is #UserName#")] [InlineData("Email: #UserEmail#")] + [InlineData("User type = #UserType#")] public void TemplateRequiresUser_ContainingKeys_ReturnsTrue(string template) { var result = IntegrationTemplateProcessor.TemplateRequiresUser(template); @@ -102,6 +103,7 @@ public class IntegrationTemplateProcessorTests [Theory] [InlineData("Acting user is #ActingUserName#")] [InlineData("Acting user's email is #ActingUserEmail#")] + [InlineData("Acting user's type is #ActingUserType#")] public void TemplateRequiresActingUser_ContainingKeys_ReturnsTrue(string template) { var result = IntegrationTemplateProcessor.TemplateRequiresActingUser(template); @@ -118,6 +120,25 @@ public class IntegrationTemplateProcessorTests Assert.False(result); } + [Theory] + [InlineData("Group name is #GroupName#!")] + [InlineData("Group: #GroupName#")] + public void TemplateRequiresGroup_ContainingKeys_ReturnsTrue(string template) + { + var result = IntegrationTemplateProcessor.TemplateRequiresGroup(template); + Assert.True(result); + } + + [Theory] + [InlineData("#GroupId#")] // This is on the base class, not fetched, so should be false + [InlineData("No Group Tokens")] + [InlineData("")] + public void TemplateRequiresGroup_EmptyInputOrNoMatchingKeys_ReturnsFalse(string template) + { + var result = IntegrationTemplateProcessor.TemplateRequiresGroup(template); + Assert.False(result); + } + [Theory] [InlineData("Organization: #OrganizationName#")] [InlineData("Welcome to #OrganizationName#")] diff --git a/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs b/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs new file mode 100644 index 0000000000..2b9b5cf194 --- /dev/null +++ b/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs @@ -0,0 +1,70 @@ +using Bit.Core.Auth.Attributes; +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Xunit; + +namespace Bit.Core.Test.Auth.Attributes; + +public class MarketingInitiativeValidationAttributeTests +{ + [Fact] + public void IsValid_NullValue_ReturnsTrue() + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(null); + + Assert.True(actual); + } + + [Theory] + [InlineData(MarketingInitiativeConstants.Premium)] + public void IsValid_AcceptedValue_ReturnsTrue(string value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.True(actual); + } + + [Theory] + [InlineData("invalid")] + [InlineData("")] + [InlineData("Premium")] // case sensitive - capitalized + [InlineData("PREMIUM")] // case sensitive - uppercase + [InlineData("premium ")] // trailing space + [InlineData(" premium")] // leading space + public void IsValid_InvalidStringValue_ReturnsFalse(string value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.False(actual); + } + + [Theory] + [InlineData(123)] // integer + [InlineData(true)] // boolean + [InlineData(45.67)] // double + public void IsValid_NonStringValue_ReturnsFalse(object value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.False(actual); + } + + [Fact] + public void ErrorMessage_ContainsAcceptedValues() + { + var sut = new MarketingInitiativeValidationAttribute(); + + var errorMessage = sut.ErrorMessage; + + Assert.NotNull(errorMessage); + Assert.Contains("premium", errorMessage); + Assert.Contains("Marketing initiative type must be one of:", errorMessage); + } +} diff --git a/test/Core.Test/Auth/Entities/AuthRequestTests.cs b/test/Core.Test/Auth/Entities/AuthRequestTests.cs new file mode 100644 index 0000000000..9efeb1ded1 --- /dev/null +++ b/test/Core.Test/Auth/Entities/AuthRequestTests.cs @@ -0,0 +1,224 @@ +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Xunit; + +namespace Bit.Core.Test.Auth.Entities; + +public class AuthRequestTests +{ + [Fact] + public void IsValidForAuthentication_WithValidRequest_ReturnsTrue() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.True(result); + } + + [Fact] + public void IsValidForAuthentication_WithWrongUserId_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var differentUserId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(differentUserId, accessCode); + + // Assert + Assert.False(result, "Auth request should not validate for a different user"); + } + + [Fact] + public void IsValidForAuthentication_WithWrongAccessCode_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = "correct-code" + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, "wrong-code"); + + // Assert + Assert.False(result); + } + + [Fact] + public void IsValidForAuthentication_WithoutResponseDate_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = null, // Not responded to + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Unanswered auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithApprovedFalse_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = false, // Denied + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Denied auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithApprovedNull_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = null, // Pending + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Pending auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithExpiredRequest_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow.AddMinutes(-20), // Expired (15 min timeout) + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Expired auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithWrongType_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.Unlock, // Wrong type + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Only AuthenticateAndUnlock type should be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithAlreadyUsed_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = DateTime.UtcNow, // Already used + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Auth requests should only be valid for one-time use"); + } +} diff --git a/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs b/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs new file mode 100644 index 0000000000..b78e96e91e --- /dev/null +++ b/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs @@ -0,0 +1,18 @@ +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Xunit; + +namespace Bit.Core.Test.Auth.Models.Api.Request.Accounts; + +/// +/// Snapshot tests to ensure the string constants in do not change unintentionally. +/// If you intentionally change any of these values, please update the tests to reflect the new expected values. +/// +public class MarketingInitiativeConstantsSnapshotTests +{ + [Fact] + public void MarketingInitiativeConstants_HaveCorrectValues() + { + // Assert + Assert.Equal("premium", MarketingInitiativeConstants.Premium); + } +} diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs index 7319df17aa..2f4d00a7fa 100644 --- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs @@ -2,7 +2,6 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; @@ -14,7 +13,6 @@ using Bit.Core.Auth.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -342,26 +340,26 @@ public class SsoConfigServiceTests await sutProvider.Sut.SaveAsync(ssoConfig, organization); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.SingleOrg && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.SingleOrg && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.ResetPassword && - t.GetDataModel().AutoEnrollEnabled && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.ResetPassword && + t.PolicyUpdate.GetDataModel().AutoEnrollEnabled && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.RequireSso && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.RequireSso && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); await sutProvider.GetDependency().ReceivedWithAnyArgs() @@ -369,7 +367,7 @@ public class SsoConfigServiceTests } [Theory, BitAutoData] - public async Task SaveAsync_Tde_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand( + public async Task SaveAsync_Tde_UsesVNextSavePolicyCommand( SutProvider sutProvider, Organization organization) { var ssoConfig = new SsoConfig @@ -383,10 +381,6 @@ public class SsoConfigServiceTests OrganizationId = organization.Id, }; - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) - .Returns(true); - await sutProvider.Sut.SaveAsync(ssoConfig, organization); await sutProvider.GetDependency() diff --git a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs index b19ae47cfc..ae669398c5 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs @@ -7,6 +7,7 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.UserFeatures.Registration.Implementations; +using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -37,6 +38,12 @@ public class RegisterUserCommandTests public async Task RegisterUser_Succeeds(SutProvider sutProvider, User user) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user) .Returns(IdentityResult.Success); @@ -61,6 +68,12 @@ public class RegisterUserCommandTests public async Task RegisterUser_WhenCreateUserFails_ReturnsIdentityResultFailed(SutProvider sutProvider, User user) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user) .Returns(IdentityResult.Failed()); @@ -80,6 +93,120 @@ public class RegisterUserCommandTests .SendWelcomeEmailAsync(Arg.Any()); } + // ----------------------------------------------------------------------------------------------- + // RegisterSSOAutoProvisionedUserAsync tests + // ----------------------------------------------------------------------------------------------- + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_Success( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Id = Guid.NewGuid(); + organization.Id = Guid.NewGuid(); + organization.Name = "Test Organization"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_UserRegistrationFails_ReturnsFailedResult( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var expectedError = new IdentityError(); + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Failed(expectedError)); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.False(result.Succeeded); + Assert.Contains(expectedError, result.Errors); + await sutProvider.GetDependency() + .DidNotReceive() + .SendOrganizationUserWelcomeEmailAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData(PlanType.EnterpriseAnnually)] + [BitAutoData(PlanType.EnterpriseMonthly)] + [BitAutoData(PlanType.TeamsAnnually)] + public async Task RegisterSSOAutoProvisionedUserAsync_EnterpriseOrg_SendsOrganizationWelcomeEmail( + PlanType planType, + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = planType; + organization.Name = "Enterprise Org"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserWelcomeEmailAsync(user, organization.Name); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_FeatureFlagDisabled_SendsLegacyWelcomeEmail( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(false); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendWelcomeEmailAsync(user); + } + // ----------------------------------------------------------------------------------------------- // RegisterUserWithOrganizationInviteToken tests // ----------------------------------------------------------------------------------------------- @@ -301,6 +428,138 @@ public class RegisterUserCommandTests Assert.Equal(expectedErrorMessage, exception.Message); } + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromDifferentOrg_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + var blockingOrganizationId = Guid.NewGuid(); // Different org that has the domain blocked + orgUser.OrganizationId = Guid.NewGuid(); // The org they're trying to join + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns(orgUser); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Mock the new overload that excludes the organization - it should return true (domain IS blocked by another org) + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", orgUser.OrganizationId) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromSameOrg_Succeeds( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@company-domain.com"; + user.ReferenceData = null; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + // The organization owns the domain and is trying to invite the user + orgUser.OrganizationId = Guid.NewGuid(); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns(orgUser); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Mock the new overload - it should return false (domain is NOT blocked by OTHER orgs) + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", orgUser.OrganizationId) + .Returns(false); + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", orgUser.OrganizationId); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_WithValidTokenButNullOrgUser_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@example.com"; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Mock GetByIdAsync to return null - simulating a deleted or non-existent organization user + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns((OrganizationUser)null); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId)); + Assert.Equal("Invalid organization user invitation.", exception.Message); + + // Verify that GetByIdAsync was called + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(orgUserId); + + // Verify that user creation was never attempted + await sutProvider.GetDependency() + .DidNotReceive() + .CreateUserAsync(Arg.Any(), Arg.Any()); + } + // ----------------------------------------------------------------------------------------------- // RegisterUserViaEmailVerificationToken tests // ----------------------------------------------------------------------------------------------- @@ -310,6 +569,12 @@ public class RegisterUserCommandTests public async Task RegisterUserViaEmailVerificationToken_Succeeds(SutProvider sutProvider, User user, string masterPasswordHash, string emailVerificationToken, bool receiveMarketingMaterials) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(emailVerificationToken, out Arg.Any()) .Returns(callInfo => @@ -342,6 +607,12 @@ public class RegisterUserCommandTests public async Task RegisterUserViaEmailVerificationToken_InvalidToken_ThrowsBadRequestException(SutProvider sutProvider, User user, string masterPasswordHash, string emailVerificationToken, bool receiveMarketingMaterials) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(emailVerificationToken, out Arg.Any()) .Returns(callInfo => @@ -380,6 +651,12 @@ public class RegisterUserCommandTests string orgSponsoredFreeFamilyPlanInviteToken) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) .Returns((true, new OrganizationSponsorship())); @@ -409,6 +686,12 @@ public class RegisterUserCommandTests string masterPasswordHash, string orgSponsoredFreeFamilyPlanInviteToken) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) .Returns((false, new OrganizationSponsorship())); @@ -446,9 +729,14 @@ public class RegisterUserCommandTests EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; emergencyAccess.Email = user.Email; emergencyAccess.Id = acceptEmergencyAccessId; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) .Returns(callInfo => @@ -482,9 +770,14 @@ public class RegisterUserCommandTests string masterPasswordHash, EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; emergencyAccess.Email = "wrong@email.com"; emergencyAccess.Id = acceptEmergencyAccessId; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) .Returns(callInfo => @@ -525,6 +818,8 @@ public class RegisterUserCommandTests User user, string masterPasswordHash, Guid providerUserId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + // Start with plaintext var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; @@ -547,6 +842,10 @@ public class RegisterUserCommandTests sutProvider.GetDependency() .OrganizationInviteExpirationHours.Returns(120); // 5 days + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user, masterPasswordHash) .Returns(IdentityResult.Success); @@ -576,6 +875,8 @@ public class RegisterUserCommandTests User user, string masterPasswordHash, Guid providerUserId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + // Start with plaintext var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; @@ -598,6 +899,10 @@ public class RegisterUserCommandTests sutProvider.GetDependency() .OrganizationInviteExpirationHours.Returns(120); // 5 days + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + // Using sutProvider in the parameters of the function means that the constructor has already run for the // command so we have to recreate it in order for our mock overrides to be used. sutProvider.Create(); @@ -646,5 +951,521 @@ public class RegisterUserCommandTests Assert.Equal("Open registration has been disabled by the system administrator.", result.Message); } + // ----------------------------------------------------------------------------------------------- + // Domain blocking tests (BlockClaimedDomainAccountCreation policy) + // ----------------------------------------------------------------------------------------------- + [Theory] + [BitAutoData] + public async Task RegisterUser_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUser(user)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + + // Verify user creation was never attempted + await sutProvider.GetDependency() + .DidNotReceive() + .CreateUserAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task RegisterUser_AllowedDomain_Succeeds( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "user@allowed-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowed-domain.com") + .Returns(false); + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterUser(user); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowed-domain.com"); + } + + // SendWelcomeEmail tests + // ----------------------------------------------------------------------------------------------- + [Theory] + [BitAutoData(PlanType.FamiliesAnnually)] + [BitAutoData(PlanType.FamiliesAnnually2019)] + [BitAutoData(PlanType.FamiliesAnnually2025)] + [BitAutoData(PlanType.Free)] + public async Task SendWelcomeEmail_FamilyOrg_SendsFamilyWelcomeEmail( + PlanType planType, + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = planType; + organization.Name = "Family Org"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, organization.Name); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaEmailVerificationToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string emailVerificationToken, bool receiveMarketingMaterials) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(emailVerificationToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new RegistrationEmailVerificationTokenable(user.Email, user.Name, receiveMarketingMaterials); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaEmailVerificationToken(user, masterPasswordHash, emailVerificationToken)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string orgSponsoredFreeFamilyPlanInviteToken) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency() + .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) + .Returns((true, new OrganizationSponsorship())); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken(user, masterPasswordHash, orgSponsoredFreeFamilyPlanInviteToken)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + emergencyAccess.Email = user.Email; + emergencyAccess.Id = acceptEmergencyAccessId; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 10); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaAcceptEmergencyAccessInviteToken(user, masterPasswordHash, acceptEmergencyAccessInviteToken, acceptEmergencyAccessId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaProviderInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, Guid providerUserId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + // Start with plaintext + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + + // Given any byte array, just return the decryptedProviderInviteTokenByteArray (sidestepping any actual encryption) + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(mockDataProtector); + + sutProvider.GetDependency() + .OrganizationInviteExpirationHours.Returns(120); // 5 days + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + // Using sutProvider in the parameters of the function means that the constructor has already run for the + // command so we have to recreate it in order for our mock overrides to be used. + sutProvider.Create(); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaProviderInviteToken(user, masterPasswordHash, base64EncodedProviderInvToken, providerUserId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + // ----------------------------------------------------------------------------------------------- + // Invalid email format tests + // ----------------------------------------------------------------------------------------------- + + [Theory] + [BitAutoData] + public async Task RegisterUser_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "invalid-email-format"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUser(user)); + Assert.Equal("Invalid email address format.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaEmailVerificationToken_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string emailVerificationToken, bool receiveMarketingMaterials) + { + // Arrange + user.Email = "invalid-email-format"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(emailVerificationToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new RegistrationEmailVerificationTokenable(user.Email, user.Name, receiveMarketingMaterials); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaEmailVerificationToken(user, masterPasswordHash, emailVerificationToken)); + Assert.Equal("Invalid email address format.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task SendWelcomeEmail_OrganizationNull_SendsIndividualWelcomeEmail( + User user, + OrganizationUser orgUser, + string orgInviteToken, + string masterPasswordHash, + SutProvider sutProvider) + { + // Arrange + user.ReferenceData = null; + orgUser.Email = user.Email; + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.Id) + .Returns(orgUser); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns((Policy)null); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.OrganizationId) + .Returns((Organization)null); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendIndividualUserWelcomeEmailAsync(user); + } + + [Theory] + [BitAutoData] + public async Task SendWelcomeEmail_OrganizationDisplayNameNull_SendsIndividualWelcomeEmail( + User user, + SutProvider sutProvider) + { + // Arrange + Organization organization = new Organization + { + Name = null + }; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendIndividualUserWelcomeEmailAsync(user); + } + + [Theory] + [BitAutoData] + public async Task GetOrganizationWelcomeEmailDetailsAsync_HappyPath_ReturnsOrganizationWelcomeEmailDetails( + Organization organization, + User user, + OrganizationUser orgUser, + string masterPasswordHash, + string orgInviteToken, + SutProvider sutProvider) + { + // Arrange + user.ReferenceData = null; + orgUser.Email = user.Email; + organization.PlanType = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.Id) + .Returns(orgUser); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns((Policy)null); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.OrganizationId) + .Returns(organization); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id); + + // Assert + Assert.True(result.Succeeded); + + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(orgUser.OrganizationId); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserWelcomeEmailAsync(user, organization.DisplayName()); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithBlockedDomain_ThrowsException( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", organization.Id) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithOwnClaimedDomain_Succeeds( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@company-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Domain is claimed by THIS organization, so it should be allowed + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", organization.Id) + .Returns(false); // Not blocked because organization.Id is excluded + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithNonClaimedDomain_Succeeds( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@unclaimed-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("unclaimed-domain.com", organization.Id) + .Returns(false); // Domain is not claimed by any org + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } } diff --git a/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs index f4f620f8a9..91e8351d2c 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.UserFeatures.Registration.Implementations; using Bit.Core.Entities; using Bit.Core.Exceptions; @@ -21,6 +22,43 @@ public class SendVerificationEmailForRegistrationCommandTests [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsNewUserAndEnableEmailVerificationTrue_SendsEmailAndReturnsNull(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .GetByEmailAsync(email) + .ReturnsNull(); + + sutProvider.GetDependency() + .EnableEmailVerification = true; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + + var mockedToken = "token"; + sutProvider.GetDependency>() + .Protect(Arg.Any()) + .Returns(mockedToken); + + // Act + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendRegistrationVerificationEmailAsync(email, mockedToken, null); + Assert.Null(result); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenFromMarketingIsPremium_SendsEmailWithMarketingParameterAndReturnsNull(SutProvider sutProvider, string email, string name, bool receiveMarketingEmails) { // Arrange @@ -34,31 +72,35 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; - sutProvider.GetDependency() - .SendRegistrationVerificationEmailAsync(email, Arg.Any()) - .Returns(Task.CompletedTask); + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); + var fromMarketing = MarketingInitiativeConstants.Premium; + // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, fromMarketing); // Assert await sutProvider.GetDependency() .Received(1) - .SendRegistrationVerificationEmailAsync(email, mockedToken); + .SendRegistrationVerificationEmailAsync(email, mockedToken, fromMarketing); Assert.Null(result); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsExistingUserAndEnableEmailVerificationTrue_ReturnsNull(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .Returns(new User()); @@ -69,27 +111,33 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); // Assert await sutProvider.GetDependency() .DidNotReceive() - .SendRegistrationVerificationEmailAsync(email, mockedToken); + .SendRegistrationVerificationEmailAsync(email, mockedToken, null); Assert.Null(result); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsNewUserAndEnableEmailVerificationFalse_ReturnsToken(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .ReturnsNull(); @@ -100,13 +148,17 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); // Assert Assert.Equal(mockedToken, result); @@ -122,15 +174,17 @@ public class SendVerificationEmailForRegistrationCommandTests .DisableUserRegistration = true; // Act & Assert - await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsExistingUserAndEnableEmailVerificationFalse_ThrowsBadRequestException(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .Returns(new User()); @@ -138,8 +192,15 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .EnableEmailVerification = false; + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + // Act & Assert - await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); } [Theory] @@ -150,7 +211,7 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; - await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run(null, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run(null, name, receiveMarketingEmails, null)); } [Theory] @@ -160,6 +221,90 @@ public class SendVerificationEmailForRegistrationCommandTests { sutProvider.GetDependency() .DisableUserRegistration = false; - await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run("", name, receiveMarketingEmails)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run("", name, receiveMarketingEmails, null)); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenBlockedDomain_ThrowsBadRequestException(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@blockedcompany.com"; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blockedcompany.com") + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenAllowedDomain_Succeeds(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@allowedcompany.com"; + + sutProvider.GetDependency() + .GetByEmailAsync(email) + .ReturnsNull(); + + sutProvider.GetDependency() + .EnableEmailVerification = false; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowedcompany.com") + .Returns(false); + + var mockedToken = "token"; + sutProvider.GetDependency>() + .Protect(Arg.Any()) + .Returns(mockedToken); + + // Act + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); + + // Assert + Assert.Equal(mockedToken, result); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = "invalid-email-format"; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); + Assert.Equal("Invalid email address format.", exception.Message); } } diff --git a/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs new file mode 100644 index 0000000000..2b448ba79f --- /dev/null +++ b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs @@ -0,0 +1,275 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Auth.UserFeatures.Sso; + +[SutProviderCustomize] +public class UserSsoOrganizationIdentifierQueryTests +{ + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasSingleConfirmedOrganization_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = "test-org-identifier"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("test-org-identifier", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasNoOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId) + { + // Arrange + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(Array.Empty()); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMultipleConfirmedOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser1, + OrganizationUser organizationUser2) + { + // Arrange + organizationUser1.UserId = userId; + organizationUser1.Status = OrganizationUserStatusType.Confirmed; + organizationUser2.UserId = userId; + organizationUser2.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser1, organizationUser2]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserStatusType.Invited)] + [BitAutoData(OrganizationUserStatusType.Accepted)] + [BitAutoData(OrganizationUserStatusType.Revoked)] + public async Task GetSsoOrganizationIdentifierAsync_UserHasOnlyInvitedOrganization_ReturnsNull( + OrganizationUserStatusType status, + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = status; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMixedStatusOrganizations_OnlyOneConfirmed_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser confirmedOrgUser, + OrganizationUser invitedOrgUser, + OrganizationUser revokedOrgUser) + { + // Arrange + confirmedOrgUser.UserId = userId; + confirmedOrgUser.OrganizationId = organization.Id; + confirmedOrgUser.Status = OrganizationUserStatusType.Confirmed; + + invitedOrgUser.UserId = userId; + invitedOrgUser.Status = OrganizationUserStatusType.Invited; + + revokedOrgUser.UserId = userId; + revokedOrgUser.Status = OrganizationUserStatusType.Revoked; + + organization.Identifier = "mixed-status-org"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { confirmedOrgUser, invitedOrgUser, revokedOrgUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("mixed-status-org", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationNotFound_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.OrganizationId) + .Returns((Organization)null); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organizationUser.OrganizationId); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsNull_ReturnsNull( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = null; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsEmpty_ReturnsEmpty( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = string.Empty; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal(string.Empty, result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } +} diff --git a/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs b/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs index adeac45d06..3a98fb44fb 100644 --- a/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs +++ b/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs @@ -1,10 +1,13 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth; +using Bit.Core.Billing.Premium.Queries; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -404,6 +407,277 @@ public class TwoFactorIsEnabledQueryTests .GetCalculatedPremiumAsync(default); } + [Theory] + [BitAutoData((IEnumerable)null)] + [BitAutoData([])] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsEmpty( + IEnumerable userIds, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert + Assert.Empty(result); + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Duo)] + [BitAutoData(TwoFactorProviderType.YubiKey)] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithMixedScenarios_ReturnsCorrectResults( + TwoFactorProviderType premiumProviderType, + SutProvider sutProvider, + User user1, + User user2, + User user3) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + // User 1: Non-premium provider → 2FA enabled + user1.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } } + }); + + // User 2: Premium provider + has premium → 2FA enabled + user2.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + // User 3: Premium provider + no premium → 2FA disabled + user3.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + var premiumStatus = new Dictionary + { + { user2.Id, true }, + { user3.Id, false } + }; + + sutProvider.GetDependency() + .GetManyAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))) + .Returns(premiumStatus); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert + Assert.Contains(result, res => res.userId == user1.Id && res.twoFactorIsEnabled == true); // Non-premium provider + Assert.Contains(result, res => res.userId == user2.Id && res.twoFactorIsEnabled == true); // Premium + has premium + Assert.Contains(result, res => res.userId == user3.Id && res.twoFactorIsEnabled == false); // Premium + no premium + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Duo)] + [BitAutoData(TwoFactorProviderType.YubiKey)] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_OnlyChecksPremiumAccessForUsersWhoNeedIt( + TwoFactorProviderType premiumProviderType, + SutProvider sutProvider, + User user1, + User user2, + User user3) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + // User 1: Has non-premium provider - should NOT trigger premium check + user1.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } } + }); + + // User 2 & 3: Have only premium providers - SHOULD trigger premium check + user2.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + user3.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + var premiumStatus = new Dictionary + { + { user2.Id, true }, + { user3.Id, false } + }; + + sutProvider.GetDependency() + .GetManyAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))) + .Returns(premiumStatus); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert - Verify optimization: premium checked ONLY for users 2 and 3 (not user 1) + await sutProvider.GetDependency() + .Received(1) + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsAllTwoFactorDisabled( + SutProvider sutProvider, + List users) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + foreach (var user in users) + { + user.UserId = null; + } + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(users); + + // Assert + foreach (var user in users) + { + Assert.Contains(result, res => res.user.Equals(user) && res.twoFactorIsEnabled == false); + } + + // No UserIds were supplied so no calls to the UserRepository should have been made + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetManyAsync(default); + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Authenticator, true)] // Non-premium provider + [BitAutoData(TwoFactorProviderType.Duo, true)] // Premium provider with premium access + [BitAutoData(TwoFactorProviderType.YubiKey, false)] // Premium provider without premium access + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_SingleUser_VariousScenarios( + TwoFactorProviderType providerType, + bool hasPremiumAccess, + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.SetTwoFactorProviders(new Dictionary + { + { providerType, new TwoFactorProvider { Enabled = true } } + }); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(user.Id) + .Returns(hasPremiumAccess); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + var requiresPremium = TwoFactorProvider.RequiresPremium(providerType); + var expectedResult = !requiresPremium || hasPremiumAccess; + Assert.Equal(expectedResult, result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoEnabledProviders_ReturnsFalse( + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Email, new TwoFactorProvider { Enabled = false } } + }); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + Assert.False(result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNullProviders_ReturnsFalse( + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.TwoFactorProviders = null; + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + Assert.False(result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_UserNotFound_ThrowsNotFoundException( + SutProvider sutProvider, + Guid userId) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var testUser = new TestTwoFactorProviderUser + { + Id = userId, + TwoFactorProviders = null + }; + + sutProvider.GetDependency() + .GetByIdAsync(userId) + .Returns((User)null); + + // Act & Assert + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.TwoFactorIsEnabledAsync(testUser)); + } + private class TestTwoFactorProviderUser : ITwoFactorProvidersUser { public Guid? Id { get; set; } @@ -418,10 +692,5 @@ public class TwoFactorIsEnabledQueryTests { return Id; } - - public bool GetPremium() - { - return Premium; - } } } diff --git a/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs index 65d9e99e3b..1a4f92a224 100644 --- a/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs +++ b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.Billing.Extensions; +using System.Globalization; +using Bit.Core.Billing.Extensions; using Stripe; using Xunit; @@ -356,9 +357,18 @@ public class InvoiceExtensionsTests [Fact] public void FormatForProvider_ComplexScenario_HandlesAllLineTypes() { - // Arrange - var lineItems = new StripeList(); - lineItems.Data = new List + // Set culture to en-US to ensure consistent decimal formatting in tests + // This ensures tests pass on all machines regardless of system locale + var originalCulture = Thread.CurrentThread.CurrentCulture; + var originalUICulture = Thread.CurrentThread.CurrentUICulture; + try + { + Thread.CurrentThread.CurrentCulture = new CultureInfo("en-US"); + Thread.CurrentThread.CurrentUICulture = new CultureInfo("en-US"); + + // Arrange + var lineItems = new StripeList(); + lineItems.Data = new List { new InvoiceLineItem { @@ -372,23 +382,29 @@ public class InvoiceExtensionsTests new InvoiceLineItem { Description = "Custom Service", Quantity = 2, Amount = 2000 } }; - var invoice = new Invoice + var invoice = new Invoice + { + Lines = lineItems, + TotalTaxes = [new InvoiceTotalTax { Amount = 200 }] // Additional $2.00 tax + }; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Equal(5, result.Count); + Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); + Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[1]); + Assert.Equal("1 × Tax (at $8.00 / month)", result[2]); + Assert.Equal("Custom Service", result[3]); + Assert.Equal("1 × Tax (at $2.00 / month)", result[4]); + } + finally { - Lines = lineItems, - TotalTaxes = [new InvoiceTotalTax { Amount = 200 }] // Additional $2.00 tax - }; - var subscription = new Subscription(); - - // Act - var result = invoice.FormatForProvider(subscription); - - // Assert - Assert.Equal(5, result.Count); - Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); - Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[1]); - Assert.Equal("1 × Tax (at $8.00 / month)", result[2]); - Assert.Equal("Custom Service", result[3]); - Assert.Equal("1 × Tax (at $2.00 / month)", result[4]); + Thread.CurrentThread.CurrentCulture = originalCulture; + Thread.CurrentThread.CurrentUICulture = originalUICulture; + } } #endregion diff --git a/test/Core.Test/Billing/Mocks/MockPlans.cs b/test/Core.Test/Billing/Mocks/MockPlans.cs new file mode 100644 index 0000000000..b4737434fb --- /dev/null +++ b/test/Core.Test/Billing/Mocks/MockPlans.cs @@ -0,0 +1,37 @@ +using Bit.Core.Billing.Enums; +using Bit.Core.Models.StaticStore; +using Bit.Core.Test.Billing.Mocks.Plans; + +namespace Bit.Core.Test.Billing.Mocks; + +public class MockPlans +{ + public static List Plans => + [ + new CustomPlan(), + new Enterprise2019Plan(false), + new Enterprise2019Plan(true), + new Enterprise2020Plan(false), + new Enterprise2020Plan(true), + new Enterprise2023Plan(false), + new Enterprise2023Plan(true), + new EnterprisePlan(false), + new EnterprisePlan(true), + new Families2019Plan(), + new Families2025Plan(), + new FamiliesPlan(), + new FreePlan(), + new Teams2019Plan(false), + new Teams2019Plan(true), + new Teams2020Plan(false), + new Teams2020Plan(true), + new Teams2023Plan(false), + new Teams2023Plan(true), + new TeamsPlan(false), + new TeamsPlan(true), + new TeamsStarterPlan(), + new TeamsStarterPlan2023() + ]; + + public static Plan Get(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType)!; +} diff --git a/src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs b/test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs similarity index 89% rename from src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs index ce55cb422e..0105b7d07f 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record CustomPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs index b584647a26..27f3710b96 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2019Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs index a1a6113cbc..8f56125fc1 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2020Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs rename to test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs index 8aeca521d1..563adc82a3 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record EnterprisePlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs index dce1719a49..f221821ed3 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2023Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs similarity index 96% rename from src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs index 93ab2c39a1..a0257d88e9 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Families2019Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs similarity index 95% rename from src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs index 77e238e98e..5f5424bbcf 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Families2025Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs b/test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs similarity index 95% rename from src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs index b2edc1168b..70aa613ee0 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record FamiliesPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs b/test/Core.Test/Billing/Mocks/Plans/FreePlan.cs similarity index 95% rename from src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs rename to test/Core.Test/Billing/Mocks/Plans/FreePlan.cs index 3b0a8b7480..307f58c803 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/FreePlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record FreePlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs index 27ed5e0bf4..f1aad7c16f 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2019Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs index a760b9692e..546f1f84c5 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2020Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs index 654792ee0b..e0ecd35346 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs index 8498af6b13..5ec2acd61c 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2023Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs similarity index 97% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs index d78844e429..119f431a56 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsStarterPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs similarity index 97% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs index ea15d9eb95..40952e75fb 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsStarterPlan2023 : Plan { diff --git a/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs b/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs index 04b579add3..d1f02af50d 100644 --- a/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs +++ b/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs @@ -213,7 +213,8 @@ If you believe you need to change the version for a valid reason, please discuss LimitCollectionDeletion = true, AllowAdminAccessToAllCollectionItems = true, UseOrganizationDomains = true, - UseAdminSponsoredFamilies = false + UseAdminSponsoredFamilies = false, + UsePhishingBlocker = false, }; } diff --git a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs index 8b3a044118..2f278dcd20 100644 --- a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; @@ -58,7 +58,7 @@ public class PreviewOrganizationTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -68,7 +68,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for sponsored subscription - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -116,7 +116,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8250 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -126,7 +126,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(82.50m, total); // Verify the correct Stripe API call for standalone secrets manager - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -179,7 +179,7 @@ public class PreviewOrganizationTaxCommandTests Total = 12200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -189,7 +189,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(122.00m, total); // Verify the correct Stripe API call for comprehensive purchase with storage and service accounts - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -240,7 +240,7 @@ public class PreviewOrganizationTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -250,7 +250,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify the correct Stripe API call for Families tier (non-seat-based plan) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -292,7 +292,7 @@ public class PreviewOrganizationTaxCommandTests Total = 2700 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -302,7 +302,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(27.00m, total); // Verify the correct Stripe API call for business use in non-US country (tax exempt reverse) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -345,7 +345,7 @@ public class PreviewOrganizationTaxCommandTests Total = 12100 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -355,7 +355,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(121.00m, total); // Verify the correct Stripe API call for Spanish NIF that adds both Spanish NIF and EU VAT tax IDs - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -405,7 +405,7 @@ public class PreviewOrganizationTaxCommandTests Total = 1320 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -415,7 +415,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(13.20m, total); // Verify the correct Stripe API call for free organization upgrade to Teams - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -458,7 +458,7 @@ public class PreviewOrganizationTaxCommandTests Total = 4400 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -468,7 +468,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(44.00m, total); // Verify the correct Stripe API call for free organization upgrade to Families (no SM for Families) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -522,7 +522,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -534,7 +534,7 @@ public class PreviewOrganizationTaxCommandTests Total = 9900 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -543,7 +543,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(9.00m, tax); Assert.Equal(99.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -597,7 +597,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -609,7 +609,7 @@ public class PreviewOrganizationTaxCommandTests Total = 13200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -618,7 +618,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(12.00m, tax); Assert.Equal(132.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -661,7 +661,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -671,7 +671,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for free organization with SM to Enterprise - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -730,7 +730,7 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -738,7 +738,7 @@ public class PreviewOrganizationTaxCommandTests Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -748,7 +748,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for existing subscription upgrade - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -814,7 +814,7 @@ public class PreviewOrganizationTaxCommandTests } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -822,7 +822,7 @@ public class PreviewOrganizationTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -832,7 +832,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call preserves existing discount - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -876,8 +876,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -919,7 +919,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -927,7 +927,7 @@ public class PreviewOrganizationTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -937,7 +937,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call for PM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -984,7 +984,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -992,7 +992,7 @@ public class PreviewOrganizationTaxCommandTests Total = 13200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1002,7 +1002,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(132.00m, total); // Verify the correct Stripe API call for PM seats + storage - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -1051,7 +1051,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1059,7 +1059,7 @@ public class PreviewOrganizationTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1069,7 +1069,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for SM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -1119,7 +1119,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1127,7 +1127,7 @@ public class PreviewOrganizationTaxCommandTests Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1137,7 +1137,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for SM seats + service accounts with tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -1200,7 +1200,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1208,7 +1208,7 @@ public class PreviewOrganizationTaxCommandTests Total = 27500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1218,7 +1218,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(275.00m, total); // Verify the correct Stripe API call for comprehensive update with discount and Spanish tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -1276,7 +1276,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1284,7 +1284,7 @@ public class PreviewOrganizationTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1294,7 +1294,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for Families tier (personal usage, no business tax exemption) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -1334,8 +1334,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } [Fact] @@ -1378,7 +1378,7 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { @@ -1386,7 +1386,7 @@ public class PreviewOrganizationTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1396,7 +1396,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify only PM seats are included (storage=0 excluded, SM seats=0 so entire SM excluded) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && diff --git a/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs index ea76f9d975..4cb4caae46 100644 --- a/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs @@ -88,7 +88,7 @@ public class UpdateOrganizationLicenseCommandTests "Hash", "Signature", "SignatureBytes", "InstallationId", "Expires", "ExpirationWithoutGracePeriod", "Token", "LimitCollectionCreationDeletion", "LimitCollectionCreation", "LimitCollectionDeletion", "AllowAdminAccessToAllCollectionItems", - "UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation") && + "UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation", "UsePhishingBlocker") && // Same property but different name, use explicit mapping org.ExpirationDate == license.Expires)); } diff --git a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs index 617a136fab..0ceb257c88 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs @@ -8,7 +8,6 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Platform.Installations; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture; using Bit.Core.Test.Billing.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -59,7 +58,7 @@ public class GetCloudOrganizationLicenseQueryTests { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); @@ -80,7 +79,7 @@ public class GetCloudOrganizationLicenseQueryTests { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); sutProvider.GetDependency() .CreateOrganizationTokenAsync(organization, installationId, subInfo) @@ -119,7 +118,7 @@ public class GetCloudOrganizationLicenseQueryTests installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); - sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs index 9f4b8474b5..e4cb0b0109 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs @@ -8,7 +8,7 @@ using Bit.Core.Billing.Services; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Settings; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -163,7 +163,7 @@ public class GetOrganizationMetadataQueryTests sutProvider.GetDependency() .GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var result = await sutProvider.Sut.Run(organization); @@ -216,7 +216,7 @@ public class GetOrganizationMetadataQueryTests sutProvider.GetDependency() .GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var result = await sutProvider.Sut.Run(organization); @@ -282,7 +282,7 @@ public class GetOrganizationMetadataQueryTests sutProvider.GetDependency() .GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var result = await sutProvider.Sut.Run(organization); @@ -349,7 +349,7 @@ public class GetOrganizationMetadataQueryTests sutProvider.GetDependency() .GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var result = await sutProvider.Sut.Run(organization); diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs index 05d24bdc34..a7284410fe 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs @@ -8,7 +8,6 @@ using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -382,7 +381,7 @@ public class GetOrganizationWarningsQueryTests var dueDate = now.AddDays(-10); - sutProvider.GetDependency().InvoiceSearchAsync(Arg.Is(options => + sutProvider.GetDependency().SearchInvoiceAsync(Arg.Is(options => options.Query == $"subscription:'{subscriptionId}' status:'open'")).Returns([ new Invoice { DueDate = dueDate, Created = dueDate.AddDays(-30) } ]); @@ -542,7 +541,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -583,7 +582,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -635,7 +634,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -687,7 +686,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -739,7 +738,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -785,7 +784,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs index c42049d5bb..5854d1c3b5 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Microsoft.Extensions.Logging; using NSubstitute; @@ -73,7 +72,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -84,7 +83,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -131,7 +130,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -144,7 +143,7 @@ public class UpdateBillingAddressCommandTests await _subscriberService.Received(1).CreateStripeCustomer(organization); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -192,7 +191,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -204,7 +203,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -260,7 +259,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -272,10 +271,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdDeleteAsync(customer.Id, "tax_id_123"); + await _stripeAdapter.Received(1).DeleteTaxIdAsync(customer.Id, "tax_id_123"); } [Fact] @@ -322,7 +321,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse @@ -334,7 +333,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -384,14 +383,14 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse )).Returns(customer); _stripeAdapter - .TaxIdCreateAsync(customer.Id, + .CreateTaxIdAsync(customer.Id, Arg.Is(options => options.Type == TaxIdType.EUVAT)) .Returns(new TaxId { Type = TaxIdType.EUVAT, Value = "ESA12345678" }); @@ -401,10 +400,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input with { TaxId = new TaxID(TaxIdType.EUVAT, "ESA12345678") }, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdCreateAsync(organization.GatewayCustomerId, Arg.Is( + await _stripeAdapter.Received(1).CreateTaxIdAsync(organization.GatewayCustomerId, Arg.Is( options => options.Type == TaxIdType.SpanishNIF && options.Value == input.TaxId.Value)); } diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs index 72280c4c77..da42127f33 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.Billing.Extensions; using Braintree; @@ -82,7 +81,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -144,7 +143,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -213,7 +212,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -232,7 +231,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("https://example.com", maskedBankAccount.HostedVerificationUrl); await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, Arg.Is(options => + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == string.Empty && options.Metadata[MetadataKeys.RetiredBraintreeCustomerId] == "braintree_customer_id")); } @@ -262,7 +261,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -291,7 +290,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); } @@ -315,7 +314,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -344,10 +343,10 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Address.Country == "US" && options.Address.PostalCode == "12345")); } @@ -468,7 +467,7 @@ public class UpdatePaymentMethodCommandTests var maskedPayPalAccount = maskedPaymentMethod.AsT2; Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id")); } diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs index b6b0d596b3..4e4c5199e2 100644 --- a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Braintree; using Microsoft.Extensions.Logging; @@ -166,7 +165,7 @@ public class GetPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))).Returns( new SetupIntent { diff --git a/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs index c7ab0c17ff..9ade4d0979 100644 --- a/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using NSubstitute; using NSubstitute.ReturnsExtensions; @@ -57,7 +56,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { @@ -162,7 +161,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { @@ -246,7 +245,7 @@ public class HasPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))) .Returns(new SetupIntent { diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index 493246c578..b58b5cd250 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -53,7 +53,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests Available = true, LegacyYear = null, Seat = new PremiumPurchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, - Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal } + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } }; _pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan); @@ -146,11 +146,11 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSetupIntent = Substitute.For(); mockSetupIntent.Id = "seti_123"; - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); - _stripeAdapter.SetupIntentList(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -158,8 +158,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -200,10 +200,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -211,8 +211,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -243,10 +243,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); @@ -255,8 +255,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); @@ -299,10 +299,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -356,8 +356,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Mock that the user has a payment method (this is the key difference from the credit purchase case) _hasPaymentMethodQuery.Run(Arg.Any()).Returns(true); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -365,7 +365,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any(), Arg.Any(), Arg.Any()); } @@ -415,8 +415,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests _updatePaymentMethodCommand.Run(Arg.Any(), Arg.Any(), Arg.Any()) .Returns(mockMaskedPaymentMethod); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -428,9 +428,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Verify GetCustomerOrThrow was called after updating payment method await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); // Verify no new customer was created - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); // Verify subscription was created - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); // Verify user was updated correctly Assert.True(user.Premium); await _userService.Received(1).SaveUserAsync(user); @@ -474,10 +474,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -525,10 +525,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -577,10 +577,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -628,13 +628,13 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SetupIntentList(Arg.Any()) + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()) .Returns(Task.FromResult(new List())); // Empty list - no setup intent found // Act @@ -681,8 +681,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -690,7 +690,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); Assert.True(user.Premium); Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate); } @@ -716,8 +716,67 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests Assert.True(result.IsT3); // Assuming T3 is the Unhandled result Assert.IsType(result.AsT3.Exception); // Verify no customer was created or subscription attempted - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateSubscriptionAsync(Arg.Any()); await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); } + + [Theory, BitAutoData] + public async Task Run_WithAdditionalStorage_SetsCorrectMaxStorageGb( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + const short additionalStorage = 2; + + // Setup premium plan with 5GB provided storage + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } + }; + _pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan); + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, additionalStorage); + + // Assert + Assert.True(result.IsT0); + Assert.Equal((short)3, user.MaxStorageGb); // 1 (provided) + 2 (additional) = 3 + await _userService.Received(1).SaveUserAsync(user); + } + } diff --git a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs index d0b2eb7aa4..b5afaf65cd 100644 --- a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; @@ -50,7 +50,7 @@ public class PreviewPremiumTaxCommandTests Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -59,7 +59,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(3.00m, tax); Assert.Equal(33.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -84,7 +84,7 @@ public class PreviewPremiumTaxCommandTests Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(5, billingAddress); @@ -93,7 +93,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(5.00m, tax); Assert.Equal(55.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -120,7 +120,7 @@ public class PreviewPremiumTaxCommandTests Total = 2750 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -129,7 +129,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(2.50m, tax); Assert.Equal(27.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -154,7 +154,7 @@ public class PreviewPremiumTaxCommandTests Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(20, billingAddress); @@ -163,7 +163,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(8.00m, tax); Assert.Equal(88.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -190,7 +190,7 @@ public class PreviewPremiumTaxCommandTests Total = 4950 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(10, billingAddress); @@ -199,7 +199,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(4.50m, tax); Assert.Equal(49.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -226,7 +226,7 @@ public class PreviewPremiumTaxCommandTests Total = 3000 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -235,7 +235,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(0.00m, tax); Assert.Equal(30.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -260,7 +260,7 @@ public class PreviewPremiumTaxCommandTests Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(-5, billingAddress); @@ -269,7 +269,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(6.00m, tax); Assert.Equal(66.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "FR" && @@ -295,7 +295,7 @@ public class PreviewPremiumTaxCommandTests Total = 3123 // $31.23 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); diff --git a/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs b/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs new file mode 100644 index 0000000000..31547dffbe --- /dev/null +++ b/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs @@ -0,0 +1,234 @@ +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Billing.Premium.Queries; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Billing.Premium.Queries; + +[SutProviderCustomize] +public class HasPremiumAccessQueryTests +{ + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasPersonalPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; + user.OrganizationPremium = false; + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumButHasOrgPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = true; // Has org premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumAndNoOrgPremium_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = false; + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.False(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserNotFound_ThrowsNotFoundException( + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetPremiumAccessAsync(userId) + .Returns((UserPremiumAccess?)null); + + // Act & Assert + await Assert.ThrowsAsync( + () => sutProvider.Sut.HasPremiumAccessAsync(userId)); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasNoOrganizations_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = false; // No premium from anywhere + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.False(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasPremiumFromOrg_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; // No personal premium + user.OrganizationPremium = true; // But has premium from org + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasOnlyPersonalPremium_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; // Has personal premium + user.OrganizationPremium = false; // Not in any org that grants premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.False(result); // Should return false because user is not in an org that grants premium + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasBothPersonalAndOrgPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; // Has personal premium + user.OrganizationPremium = true; // Also in an org that grants premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.True(result); // Should return true because user IS in an org that grants premium (regardless of personal premium) + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserNotFound_ThrowsNotFoundException( + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetPremiumAccessAsync(userId) + .Returns((UserPremiumAccess?)null); + + // Act & Assert + await Assert.ThrowsAsync( + () => sutProvider.Sut.HasPremiumFromOrganizationAsync(userId)); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_Bulk_WhenEmptyList_ReturnsEmptyDictionary( + SutProvider sutProvider) + { + // Arrange + var userIds = new List(); + + sutProvider.GetDependency() + .GetPremiumAccessByIdsAsync(userIds) + .Returns(new List()); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds); + + // Assert + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_Bulk_ReturnsCorrectStatus( + UserPremiumAccess user1, + UserPremiumAccess user2, + UserPremiumAccess user3, + SutProvider sutProvider) + { + // Arrange + user1.PersonalPremium = true; + user1.OrganizationPremium = false; + user2.PersonalPremium = false; + user2.OrganizationPremium = false; + user3.PersonalPremium = false; + user3.OrganizationPremium = true; + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + sutProvider.GetDependency() + .GetPremiumAccessByIdsAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds); + + // Assert + Assert.Equal(3, result.Count); + Assert.True(result[user1.Id]); // Personal premium + Assert.False(result[user2.Id]); // No premium + Assert.True(result[user3.Id]); // Organization premium + } +} diff --git a/test/Core.Test/Billing/Pricing/PricingClientTests.cs b/test/Core.Test/Billing/Pricing/PricingClientTests.cs index 189df15b9c..43329e9c2e 100644 --- a/test/Core.Test/Billing/Pricing/PricingClientTests.cs +++ b/test/Core.Test/Billing/Pricing/PricingClientTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Services; -using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Logging; @@ -34,7 +33,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -70,7 +68,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -109,7 +106,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -144,7 +140,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -179,7 +174,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -217,7 +211,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -258,7 +251,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -297,7 +289,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -339,33 +330,12 @@ public class PricingClientTests Assert.Null(result); } - [Theory, BitAutoData] - public async Task GetPlan_WhenPricingServiceDisabled_ReturnsStaticStorePlan( - SutProvider sutProvider) - { - // Arrange - sutProvider.GetDependency().SelfHosted = false; - - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.UsePricingService) - .Returns(false); - - // Act - var result = await sutProvider.Sut.GetPlan(PlanType.FamiliesAnnually); - - // Assert - Assert.NotNull(result); - Assert.Equal(PlanType.FamiliesAnnually, result.Type); - } - [Theory, BitAutoData] public async Task GetPlan_WhenLookupKeyNotFound_ReturnsNull( SutProvider sutProvider) { // Arrange - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.UsePricingService) - .Returns(true); + sutProvider.GetDependency().SelfHosted = false; // Act - Using PlanType that doesn't have a lookup key mapping var result = await sutProvider.Sut.GetPlan(unchecked((PlanType)999)); @@ -384,7 +354,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -413,7 +382,6 @@ public class PricingClientTests var featureService = Substitute.For(); featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; @@ -450,26 +418,6 @@ public class PricingClientTests Assert.Empty(result); } - [Theory, BitAutoData] - public async Task ListPlans_WhenPricingServiceDisabled_ReturnsStaticStorePlans( - SutProvider sutProvider) - { - // Arrange - sutProvider.GetDependency().SelfHosted = false; - - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.UsePricingService) - .Returns(false); - - // Act - var result = await sutProvider.Sut.ListPlans(); - - // Assert - Assert.NotNull(result); - Assert.NotEmpty(result); - Assert.Equal(StaticStore.Plans.Count(), result.Count); - } - [Fact] public async Task ListPlans_WhenPricingServiceReturnsError_ThrowsBillingException() { @@ -479,7 +427,6 @@ public class PricingClientTests .Respond(HttpStatusCode.InternalServerError); var featureService = Substitute.For(); - featureService.IsEnabled(FeatureFlagKeys.UsePricingService).Returns(true); var globalSettings = new GlobalSettings { SelfHosted = false }; diff --git a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs index 40fa4c412d..f1b9446b6d 100644 --- a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs +++ b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs @@ -9,8 +9,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -31,10 +30,10 @@ public class OrganizationBillingServiceTests SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().ListPlans().Returns(StaticStore.Plans.ToList()); + sutProvider.GetDependency().ListPlans().Returns(MockPlans.Plans.ToList()); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var subscriberService = sutProvider.GetDependency(); var organizationSeatCount = new OrganizationSeatCounts { Users = 1, Sponsored = 0 }; @@ -97,10 +96,10 @@ public class OrganizationBillingServiceTests { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().ListPlans().Returns(StaticStore.Plans.ToList()); + sutProvider.GetDependency().ListPlans().Returns(MockPlans.Plans.ToList()); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) @@ -134,7 +133,7 @@ public class OrganizationBillingServiceTests SutProvider sutProvider) { // Arrange - var plan = StaticStore.GetPlan(PlanType.TeamsAnnually); + var plan = MockPlans.Get(PlanType.TeamsAnnually); organization.PlanType = PlanType.TeamsAnnually; organization.GatewayCustomerId = "cus_test123"; organization.GatewaySubscriptionId = null; @@ -178,7 +177,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -195,7 +194,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(7, capturedOptions.TrialPeriodDays); @@ -210,7 +209,7 @@ public class OrganizationBillingServiceTests SutProvider sutProvider) { // Arrange - var plan = StaticStore.GetPlan(PlanType.TeamsAnnually); + var plan = MockPlans.Get(PlanType.TeamsAnnually); organization.PlanType = PlanType.TeamsAnnually; organization.GatewayCustomerId = "cus_test123"; organization.GatewaySubscriptionId = null; @@ -254,7 +253,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -271,7 +270,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(0, capturedOptions.TrialPeriodDays); @@ -284,7 +283,7 @@ public class OrganizationBillingServiceTests SutProvider sutProvider) { // Arrange - var plan = StaticStore.GetPlan(PlanType.TeamsAnnually); + var plan = MockPlans.Get(PlanType.TeamsAnnually); organization.PlanType = PlanType.TeamsAnnually; organization.GatewayCustomerId = "cus_test123"; organization.GatewaySubscriptionId = null; @@ -328,7 +327,7 @@ public class OrganizationBillingServiceTests SubscriptionCreateOptions capturedOptions = null; sutProvider.GetDependency() - .SubscriptionCreateAsync(Arg.Do(options => capturedOptions = options)) + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) .Returns(new Subscription { Id = "sub_test123", @@ -345,7 +344,7 @@ public class OrganizationBillingServiceTests // Assert await sutProvider.GetDependency() .Received(1) - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); Assert.NotNull(capturedOptions); Assert.Equal(7, capturedOptions.TrialPeriodDays); @@ -353,4 +352,173 @@ public class OrganizationBillingServiceTests } #endregion + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_UpdatesStripeCustomer( + Organization organization, + SutProvider sutProvider) + { + organization.Name = "Short name"; + + CustomerUpdateOptions capturedOptions = null; + sutProvider.GetDependency() + .UpdateCustomerAsync( + Arg.Is(id => id == organization.GatewayCustomerId), + Arg.Do(options => capturedOptions = options)) + .Returns(new Customer()); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.Equal(organization.BillingEmail, capturedOptions.Email); + Assert.Equal(organization.DisplayName(), capturedOptions.Description); + Assert.NotNull(capturedOptions.InvoiceSettings); + Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields); + Assert.Single(capturedOptions.InvoiceSettings.CustomFields); + + var customField = capturedOptions.InvoiceSettings.CustomFields.First(); + Assert.Equal(organization.SubscriberType(), customField.Name); + Assert.Equal(organization.DisplayName(), customField.Value); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_UsesFullName( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var longName = "This is a very long organization name that exceeds thirty characters"; + organization.Name = longName; + + CustomerUpdateOptions capturedOptions = null; + sutProvider.GetDependency() + .UpdateCustomerAsync( + Arg.Is(id => id == organization.GatewayCustomerId), + Arg.Do(options => capturedOptions = options)) + .Returns(new Customer()); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.NotNull(capturedOptions.InvoiceSettings); + Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields); + + var customField = capturedOptions.InvoiceSettings.CustomFields.First(); + Assert.Equal(longName, customField.Value); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.GatewayCustomerId = null; + organization.Name = "Test Organization"; + organization.BillingEmail = "billing@example.com"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsEmpty_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.GatewayCustomerId = ""; + organization.Name = "Test Organization"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsNull_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = null; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsEmpty_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = ""; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenBillingEmailIsNull_UpdatesWithNull( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = "Test Organization"; + organization.BillingEmail = null; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Is(options => + options.Email == null && + options.Description == organization.Name)); + } } diff --git a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs index 06a408c5a8..cd4c5effbe 100644 --- a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs +++ b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs @@ -1,9 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -19,7 +19,7 @@ public class PaymentHistoryServiceTests var subscriber = new Organization { GatewayCustomerId = "cus_id", GatewaySubscriptionId = "sub_id" }; var invoices = new List { new() { Id = "in_id" } }; var stripeAdapter = Substitute.For(); - stripeAdapter.InvoiceListAsync(Arg.Any()).Returns(invoices); + stripeAdapter.ListInvoicesAsync(Arg.Any()).Returns(invoices); var transactionRepository = Substitute.For(); var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository); @@ -29,7 +29,7 @@ public class PaymentHistoryServiceTests // Assert Assert.NotEmpty(result); Assert.Single(result); - await stripeAdapter.Received(1).InvoiceListAsync(Arg.Any()); + await stripeAdapter.Received(1).ListInvoicesAsync(Arg.Any()); } [Fact] diff --git a/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs new file mode 100644 index 0000000000..73f28113ca --- /dev/null +++ b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs @@ -0,0 +1,411 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class StripePaymentServiceTests +{ + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithCustomerDiscount_ReturnsDiscountFromCustomer( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var customerDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 20m, + AmountOff = 1400 + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = customerDiscount + }, + Discounts = new List(), // Empty list + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + Assert.Equal(14.00m, result.CustomerDiscount.AmountOff); // Converted from cents + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithoutCustomerDiscount_FallsBackToSubscriptionDiscounts( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscriptionDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 15m, + AmountOff = null + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + Discounts = new List { subscriptionDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should use subscription discount as fallback + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(15m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithBothDiscounts_PrefersCustomerDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var customerDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 25m + }, + End = null + }; + + var subscriptionDiscount = new Discount + { + Coupon = new Coupon + { + Id = "different-coupon-id", + PercentOff = 10m + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = customerDiscount // Should prefer this + }, + Discounts = new List { subscriptionDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should prefer customer discount over subscription discount + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(25m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNoDiscounts_ReturnsNullDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null + }, + Discounts = new List(), // Empty list, no discounts + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithMultipleSubscriptionDiscounts_SelectsFirstDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Multiple subscription-level discounts, no customer discount + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var firstDiscount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-10-percent", + PercentOff = 10m + }, + End = null + }; + + var secondDiscount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-20-percent", + PercentOff = 20m + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + // Multiple subscription discounts - FirstOrDefault() should select the first one + Discounts = new List { firstDiscount, secondDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should select the first discount from the list (FirstOrDefault() behavior) + Assert.NotNull(result.CustomerDiscount); + Assert.Equal("coupon-10-percent", result.CustomerDiscount.Id); + Assert.Equal(10m, result.CustomerDiscount.PercentOff); + // Verify the second discount was not selected + Assert.NotEqual("coupon-20-percent", result.CustomerDiscount.Id); + Assert.NotEqual(20m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNullCustomer_HandlesGracefully( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Subscription with null Customer (defensive null check scenario) + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = null, // Customer not expanded or null + Discounts = new List(), // Empty discounts + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should handle null Customer gracefully without throwing NullReferenceException + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNullDiscounts_HandlesGracefully( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Subscription with null Discounts (defensive null check scenario) + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + Discounts = null, // Discounts not expanded or null + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should handle null Discounts gracefully without throwing NullReferenceException + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_VerifiesCorrectExpandOptions( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer { Discount = null }, + Discounts = new List(), // Empty list + Items = new StripeList { Data = [] } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .GetSubscriptionAsync( + Arg.Any(), + Arg.Any()) + .Returns(subscription); + + // Act + await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Verify expand options are correct + await stripeAdapter.Received(1).GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Is(o => + o.Expand.Contains("customer.discount.coupon.applies_to") && + o.Expand.Contains("discounts.coupon.applies_to") && + o.Expand.Contains("test_clock"))); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithEmptyGatewaySubscriptionId_ReturnsEmptySubscriptionInfo( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.GatewaySubscriptionId = null; + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.NotNull(result); + Assert.Null(result.Subscription); + Assert.Null(result.CustomerDiscount); + Assert.Null(result.UpcomingInvoice); + + // Verify no Stripe API calls were made + await sutProvider.GetDependency() + .DidNotReceive() + .GetSubscriptionAsync(Arg.Any(), Arg.Any()); + } +} diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 2569ffff00..2f938065e5 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,10 +3,10 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Tax.Models; using Bit.Core.Enums; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -44,7 +44,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); await ThrowsBillingExceptionAsync(() => @@ -52,11 +52,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -81,7 +81,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -95,12 +95,12 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is( + .UpdateSubscriptionAsync(subscriptionId, Arg.Is( options => options.Metadata["cancellingUserId"] == userId.ToString())); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -127,7 +127,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -141,11 +141,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -170,7 +170,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -184,7 +184,7 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is(options => + .UpdateSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancelAtPeriodEnd == true && options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason && @@ -192,7 +192,7 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -223,7 +223,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -237,7 +237,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -253,7 +253,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomer(organization); @@ -287,7 +287,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); @@ -301,7 +301,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -318,7 +318,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); @@ -328,157 +328,6 @@ public class SubscriberServiceTests #endregion - #region GetPaymentMethod - - [Theory, BitAutoData] - public async Task GetPaymentMethod_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GetPaymentSource(null)); - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithNegativeStripeAccountBalance_ReturnsCorrectAccountCreditAmount(Organization organization, - SutProvider sutProvider) - { - // Arrange - // Stripe reports balance in cents as a negative number for credit - const int stripeAccountBalance = -593; // $5.93 credit (negative cents) - const decimal creditAmount = 5.93M; // Same value in dollars - - - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(creditAmount, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - - } - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithZeroStripeAccountBalance_ReturnsCorrectAccountCreditAmount( - Organization organization, SutProvider sutProvider) - { - // Arrange - const int stripeAccountBalance = 0; - - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(0, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - } - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithPositiveStripeAccountBalance_ReturnsCorrectAccountCreditAmount( - Organization organization, SutProvider sutProvider) - { - // Arrange - const int stripeAccountBalance = 593; // $5.93 charge balance - const decimal accountBalance = -5.93M; // account balance - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(accountBalance, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - - } - #endregion - #region GetPaymentSource [Theory, BitAutoData] @@ -502,7 +351,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -539,7 +388,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -593,7 +442,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -629,7 +478,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -649,7 +498,7 @@ public class SubscriberServiceTests { var customer = new Customer { Id = provider.GatewayCustomerId }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -672,7 +521,7 @@ public class SubscriberServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntent.Id); - sutProvider.GetDependency().SetupIntentGet(setupIntent.Id, + sutProvider.GetDependency().GetSetupIntentAsync(setupIntent.Id, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(setupIntent); var paymentMethod = await sutProvider.Sut.GetPaymentSource(provider); @@ -692,7 +541,7 @@ public class SubscriberServiceTests DefaultSource = new BankAccount { Status = "verified", BankName = "Chase", Last4 = "9999" } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -715,7 +564,7 @@ public class SubscriberServiceTests DefaultSource = new Card { Brand = "Visa", Last4 = "9999", ExpMonth = 9, ExpYear = 2028 } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -747,7 +596,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -787,7 +636,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -801,7 +650,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -817,7 +666,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscription(organization); @@ -849,7 +698,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); @@ -863,7 +712,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -880,7 +729,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); @@ -889,65 +738,6 @@ public class SubscriberServiceTests } #endregion - #region GetTaxInformation - - [Theory, BitAutoData] - public async Task GetTaxInformation_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GetTaxInformation(null)); - - [Theory, BitAutoData] - public async Task GetTaxInformation_NullAddress_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(new Customer()); - - var taxInformation = await sutProvider.Sut.GetTaxInformation(organization); - - Assert.Null(taxInformation); - } - - [Theory, BitAutoData] - public async Task GetTaxInformation_Success( - Organization organization, - SutProvider sutProvider) - { - var address = new Address - { - Country = "US", - PostalCode = "12345", - Line1 = "123 Example St.", - Line2 = "Unit 1", - City = "Example Town", - State = "NY" - }; - - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(new Customer - { - Address = address, - TaxIds = new StripeList - { - Data = [new TaxId { Value = "tax_id" }] - } - }); - - var taxInformation = await sutProvider.Sut.GetTaxInformation(organization); - - Assert.NotNull(taxInformation); - Assert.Equal(address.Country, taxInformation.Country); - Assert.Equal(address.PostalCode, taxInformation.PostalCode); - Assert.Equal("tax_id", taxInformation.TaxId); - Assert.Equal(address.Line1, taxInformation.Line1); - Assert.Equal(address.Line2, taxInformation.Line2); - Assert.Equal(address.City, taxInformation.City); - Assert.Equal(address.State, taxInformation.State); - } - - #endregion - #region RemovePaymentMethod [Theory, BitAutoData] public async Task RemovePaymentMethod_NullSubscriber_ThrowsArgumentNullException( @@ -970,7 +760,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (braintreeGateway, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1005,7 +795,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1042,7 +832,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1097,7 +887,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1156,21 +946,21 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List())); await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.Received(1).BankAccountDeleteAsync(stripeCustomer.Id, bankAccountId); + await stripeAdapter.Received(1).DeleteBankAccountAsync(stripeCustomer.Id, bankAccountId); - await stripeAdapter.Received(1).CardDeleteAsync(stripeCustomer.Id, cardId); + await stripeAdapter.Received(1).DeleteCardAsync(stripeCustomer.Id, cardId); await stripeAdapter.DidNotReceiveWithAnyArgs() - .PaymentMethodDetachAsync(Arg.Any(), Arg.Any()); + .DetachPaymentMethodAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -1188,11 +978,11 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List { new () @@ -1207,15 +997,15 @@ public class SubscriberServiceTests await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.DidNotReceiveWithAnyArgs().BankAccountDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteBankAccountAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.DidNotReceiveWithAnyArgs().CardDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteCardAsync(Arg.Any(), Arg.Any()); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(bankAccountId); + .DetachPaymentMethodAsync(bankAccountId); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(cardId); + .DetachPaymentMethodAsync(cardId); } private static async IAsyncEnumerable GetPaymentMethodsAsync( @@ -1260,7 +1050,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1272,7 +1062,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1286,10 +1076,10 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([new SetupIntent(), new SetupIntent()]); await ThrowsBillingExceptionAsync(() => @@ -1303,7 +1093,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1317,10 +1107,10 @@ public class SubscriberServiceTests var matchingSetupIntent = new SetupIntent { Id = "setup_intent_1" }; - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([matchingSetupIntent]); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); @@ -1329,12 +1119,12 @@ public class SubscriberServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_1"); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); } @@ -1345,7 +1135,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")) ) @@ -1358,22 +1148,22 @@ public class SubscriberServiceTests } }); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.Card, "TOKEN")); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).PaymentMethodAttachAsync("TOKEN", Arg.Is( + await stripeAdapter.Received(1).AttachPaymentMethodAsync("TOKEN", Arg.Is( options => options.Customer == provider.GatewayCustomerId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.InvoiceSettings.DefaultPaymentMethod == "TOKEN" && options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); @@ -1386,7 +1176,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1412,7 +1202,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1450,7 +1240,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1504,7 +1294,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1573,7 +1363,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId @@ -1605,7 +1395,7 @@ public class SubscriberServiceTests new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN"))); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -1615,7 +1405,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1652,7 +1442,7 @@ public class SubscriberServiceTests await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN")); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == braintreeCustomerId)); } @@ -1683,7 +1473,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1697,7 +1487,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "US" && @@ -1732,12 +1522,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1746,13 +1536,13 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -1765,7 +1555,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1779,7 +1569,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "CA" && @@ -1815,12 +1605,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1829,63 +1619,21 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } #endregion - #region VerifyBankAccount - - [Theory, BitAutoData] - public async Task VerifyBankAccount_NoSetupIntentId_ThrowsBillingException( - Provider provider, - SutProvider sutProvider) => await ThrowsBillingExceptionAsync(() => sutProvider.Sut.VerifyBankAccount(provider, "")); - - [Theory, BitAutoData] - public async Task VerifyBankAccount_MakesCorrectInvocations( - Provider provider, - SutProvider sutProvider) - { - const string descriptorCode = "SM1234"; - - var setupIntent = new SetupIntent - { - Id = "setup_intent_id", - PaymentMethodId = "payment_method_id" - }; - - sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntent.Id); - - var stripeAdapter = sutProvider.GetDependency(); - - stripeAdapter.SetupIntentGet(setupIntent.Id).Returns(setupIntent); - - await sutProvider.Sut.VerifyBankAccount(provider, descriptorCode); - - await stripeAdapter.Received(1).SetupIntentVerifyMicroDeposit(setupIntent.Id, - Arg.Is( - options => options.DescriptorCode == descriptorCode)); - - await stripeAdapter.Received(1).PaymentMethodAttachAsync(setupIntent.PaymentMethodId, - Arg.Is( - options => options.Customer == provider.GatewayCustomerId)); - - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( - options => options.InvoiceSettings.DefaultPaymentMethod == setupIntent.PaymentMethodId)); - } - - #endregion - #region IsValidGatewayCustomerIdAsync [Theory, BitAutoData] @@ -1907,7 +1655,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1921,7 +1669,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1930,12 +1678,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer()); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Returns(new Customer()); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } [Theory, BitAutoData] @@ -1945,12 +1693,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } #endregion @@ -1976,7 +1724,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1990,7 +1738,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1999,12 +1747,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } [Theory, BitAutoData] @@ -2014,12 +1762,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } #endregion diff --git a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs index 570f94575f..41f8839eb4 100644 --- a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs @@ -6,7 +6,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Entities; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -98,13 +97,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(organization).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(organization); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is((SubscriptionCreateOptions options) => + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is((SubscriptionCreateOptions options) => options.AutomaticTax.Enabled == true && options.CollectionMethod == CollectionMethod.ChargeAutomatically && options.Customer == "cus_123" && @@ -154,13 +153,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(provider).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(provider); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _providerRepository.Received(1).ReplaceAsync(Arg.Is(prov => prov.Id == providerId && @@ -199,13 +198,13 @@ public class RestartSubscriptionCommandTests }; _subscriberService.GetSubscription(user).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(user); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => u.Id == userId && diff --git a/test/Core.Test/Context/CurrentContextTests.cs b/test/Core.Test/Context/CurrentContextTests.cs index b868d6ceaa..41a54a5b22 100644 --- a/test/Core.Test/Context/CurrentContextTests.cs +++ b/test/Core.Test/Context/CurrentContextTests.cs @@ -107,30 +107,6 @@ public class CurrentContextTests Assert.Equal(deviceType, sutProvider.Sut.DeviceType); } - [Theory, BitAutoData] - public async Task BuildAsync_HttpContext_SetsCloudflareFlags( - SutProvider sutProvider) - { - var httpContext = new DefaultHttpContext(); - var globalSettings = new Core.Settings.GlobalSettings(); - sutProvider.Sut.BotScore = null; - // Arrange - var botScore = 85; - httpContext.Request.Headers["X-Cf-Bot-Score"] = botScore.ToString(); - httpContext.Request.Headers["X-Cf-Worked-Proxied"] = "1"; - httpContext.Request.Headers["X-Cf-Is-Bot"] = "1"; - httpContext.Request.Headers["X-Cf-Maybe-Bot"] = "1"; - - // Act - await sutProvider.Sut.BuildAsync(httpContext, globalSettings); - - // Assert - Assert.True(sutProvider.Sut.CloudflareWorkerProxied); - Assert.True(sutProvider.Sut.IsBot); - Assert.True(sutProvider.Sut.MaybeBot); - Assert.Equal(botScore, sutProvider.Sut.BotScore); - } - [Theory, BitAutoData] public async Task BuildAsync_HttpContext_SetsClientVersion( SutProvider sutProvider) diff --git a/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs b/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs new file mode 100644 index 0000000000..612d63f289 --- /dev/null +++ b/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs @@ -0,0 +1,86 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Queries; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.KeyManagement.Queries; + +[SutProviderCustomize] +public class KeyConnectorConfirmationDetailsQueryTests +{ + [Theory] + [BitAutoData] + public async Task Run_OrganizationNotFound_Throws(SutProvider sutProvider, + Guid userId, string orgSsoIdentifier) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .ReceivedWithAnyArgs(0) + .GetByOrganizationAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task Run_OrganizationNotKeyConnector_Throws( + SutProvider sutProvider, + Guid userId, string orgSsoIdentifier, Organization org) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = false; + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .ReceivedWithAnyArgs(0) + .GetByOrganizationAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task Run_OrganizationUserNotFound_Throws(SutProvider sutProvider, + Guid userId, string orgSsoIdentifier + , Organization org) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = true; + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + sutProvider.GetDependency() + .GetByOrganizationAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(null)); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .Received(1) + .GetByOrganizationAsync(org.Id, userId); + } + + [Theory] + [BitAutoData] + public async Task Run_Success(SutProvider sutProvider, Guid userId, + string orgSsoIdentifier + , Organization org, OrganizationUser orgUser) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = true; + orgUser.OrganizationId = org.Id; + orgUser.UserId = userId; + + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + sutProvider.GetDependency().GetByOrganizationAsync(org.Id, userId) + .Returns(orgUser); + + var result = await sutProvider.Sut.Run(orgSsoIdentifier, userId); + + Assert.Equal(org.Name, result.OrganizationName); + await sutProvider.GetDependency() + .Received(1) + .GetByOrganizationAsync(org.Id, userId); + } +} diff --git a/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs index dee805033a..39374755eb 100644 --- a/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs @@ -2,7 +2,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -17,7 +17,7 @@ public class CompleteSubscriptionUpdateTests public void UpgradeItemOptions_TeamsStarterToTeams_ReturnsCorrectOptions( Organization organization) { - var teamsStarterPlan = StaticStore.GetPlan(PlanType.TeamsStarter); + var teamsStarterPlan = MockPlans.Get(PlanType.TeamsStarter); var subscription = new Subscription { @@ -35,7 +35,7 @@ public class CompleteSubscriptionUpdateTests } }; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -66,7 +66,7 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -102,7 +102,7 @@ public class CompleteSubscriptionUpdateTests } }; - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -173,7 +173,7 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -209,7 +209,7 @@ public class CompleteSubscriptionUpdateTests } }; - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -277,8 +277,8 @@ public class CompleteSubscriptionUpdateTests public void RevertItemOptions_TeamsStarterToTeams_ReturnsCorrectOptions( Organization organization) { - var teamsStarterPlan = StaticStore.GetPlan(PlanType.TeamsStarter); - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsStarterPlan = MockPlans.Get(PlanType.TeamsStarter); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -325,8 +325,8 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var subscription = new Subscription { @@ -431,8 +431,8 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var subscription = new Subscription { diff --git a/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs index b6e9f63640..d96f9fea95 100644 --- a/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class SeatSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { @@ -69,7 +69,7 @@ public class SeatSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { diff --git a/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs index 20405b07b0..1f75b6a23a 100644 --- a/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs @@ -4,7 +4,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; @@ -16,7 +16,7 @@ public class SecretsManagerSubscriptionUpdateTests private static TheoryData ToPlanTheory(List types) { var theoryData = new TheoryData(); - var plans = types.Select(StaticStore.GetPlan).ToArray(); + var plans = types.Select(MockPlans.Get).ToArray(); theoryData.AddRange(plans); return theoryData; } diff --git a/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs index 3663277933..a1e9669c87 100644 --- a/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class ServiceAccountSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { @@ -69,7 +69,7 @@ public class ServiceAccountSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 5; var subscription = new Subscription diff --git a/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs index ee9dc615b6..d9fcaf991e 100644 --- a/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class SmSeatSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 3; var subscription = new Subscription @@ -70,7 +70,7 @@ public class SmSeatSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 5; var subscription = new Subscription diff --git a/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs index 79b29fcd0c..21326c5324 100644 --- a/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs @@ -1,6 +1,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -26,7 +26,7 @@ public class StorageSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var subscription = new Subscription { Items = new StripeList @@ -77,7 +77,7 @@ public class StorageSubscriptionUpdateTests [BitAutoData(PlanType.TeamsStarter)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var subscription = new Subscription { Items = new StripeList diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs index 786a6f6c0d..a6db6ae8fd 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,7 +13,7 @@ public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseT protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, OrganizationSponsorship sponsorship, SutProvider sutProvider) { - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); if (sponsorship != null) @@ -46,7 +47,7 @@ OrganizationSponsorship sponsorship, SutProvider sutProvider) protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .RemoveOrganizationSponsorshipAsync(default, default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .UpsertAsync(default); diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs index 69e7183c65..127cc7e502 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs @@ -1,10 +1,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -82,7 +82,7 @@ public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() .SponsorOrganizationAsync(default, default); await sutProvider.GetDependency() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs index 5feee0f13a..515b4d7ba1 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs @@ -1,22 +1,22 @@ using Bit.Core.Billing.Enums; using Bit.Core.Enums; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; public abstract class FamiliesForEnterpriseTestsBase { public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable FamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonConfirmedOrganizationUsersStatuses => Enum.GetValues() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs index 02ae40798b..83e1487b01 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs @@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -42,7 +43,7 @@ public class AddSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(plan); await sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts); @@ -54,7 +55,7 @@ public class AddSecretsManagerSubscriptionCommandTests c.AdditionalServiceAccounts == additionalServiceAccounts && c.AdditionalSeats == organization.Seats.GetValueOrDefault())); - await sutProvider.GetDependency().Received() + await sutProvider.GetDependency().Received() .AddSecretsManagerToSubscription(organization, plan, additionalSmSeats, additionalServiceAccounts); // TODO: call ReferenceEventService - see AC-1481 @@ -88,7 +89,7 @@ public class AddSecretsManagerSubscriptionCommandTests organization.GatewayCustomerId = null; organization.PlanType = PlanType.EnterpriseAnnually; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts)); Assert.Contains("No payment method found.", exception.Message); @@ -106,7 +107,7 @@ public class AddSecretsManagerSubscriptionCommandTests organization.GatewaySubscriptionId = null; organization.PlanType = PlanType.EnterpriseAnnually; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts)); Assert.Contains("No subscription found.", exception.Message); @@ -139,7 +140,7 @@ public class AddSecretsManagerSubscriptionCommandTests provider.Type = ProviderType.Msp; sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpAsync(organization, 10, 10)); @@ -150,7 +151,7 @@ public class AddSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AddSecretsManagerToSubscription(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs index 1e764de6d7..510433a2fa 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -11,7 +12,7 @@ using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -26,7 +27,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests private static TheoryData ToPlanTheory(List types) { var theoryData = new TheoryData(); - var plans = types.Select(StaticStore.GetPlan).ToArray(); + var plans = types.Select(MockPlans.Get).ToArray(); theoryData.AddRange(plans); return theoryData; } @@ -86,9 +87,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -136,9 +137,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -164,7 +165,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, autoscaling).AdjustSeats(2); sutProvider.GetDependency().SelfHosted.Returns(true); @@ -180,7 +181,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider, Organization organization) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); organization.UseSecretsManager = false; var update = new SecretsManagerSubscriptionUpdate(organization, plan, false); @@ -258,7 +259,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( + await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( Arg.Is(o => o.Id == organizationId), plan, expectedSmServiceAccountsExcludingBase); @@ -289,7 +290,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.MaxAutoscaleSmSeats = maxSeatCount; organization.PlanType = PlanType.EnterpriseAnnually; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { @@ -334,7 +335,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests var ownerDetailsList = new List { new() { Email = "owner@example.com" } }; organization.PlanType = PlanType.EnterpriseAnnually; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { @@ -372,7 +373,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmSeats = null; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustSeats(1); var exception = await Assert.ThrowsAsync( @@ -388,7 +389,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustSeats(-2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -404,7 +405,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustSeats(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -422,7 +423,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmSeats = 9; organization.MaxAutoscaleSmSeats = 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustSeats(2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -436,7 +437,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = organization.SmSeats + 10, @@ -455,7 +456,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = 0, @@ -475,7 +476,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmSeats = 8; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = 7, @@ -498,7 +499,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests var smServiceAccounts = 300; var existingServiceAccountCount = 299; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccounts, @@ -531,7 +532,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { var smServiceAccounts = 300; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccounts, @@ -571,7 +572,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmServiceAccounts = null; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustServiceAccounts(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -585,7 +586,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustServiceAccounts(-2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -601,7 +602,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustServiceAccounts(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -619,7 +620,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = 9; organization.MaxAutoscaleSmServiceAccounts = 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustServiceAccounts(2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -639,7 +640,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = smServiceAccount - 5; organization.MaxAutoscaleSmServiceAccounts = 2 * smServiceAccount; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccount, @@ -662,7 +663,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = newSmServiceAccounts - 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = newSmServiceAccounts, @@ -707,7 +708,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmSeats = smSeats - 1; organization.MaxAutoscaleSmSeats = smSeats * 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = smSeats, @@ -728,7 +729,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; organization.SmSeats = 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmSeats = 3 @@ -748,7 +749,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; organization.SmSeats = 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmSeats = 2 @@ -769,7 +770,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.PlanType = planType; organization.SmServiceAccounts = 3; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmServiceAccounts = 3 }; var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -779,9 +780,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustSmSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustServiceAccountsAsync(Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 await sutProvider.GetDependency().DidNotReceive() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs index 704f89ba3f..8a00604bb0 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -8,7 +9,7 @@ using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -45,7 +46,7 @@ public class UpgradeOrganizationPlanCommandTests SutProvider sutProvider) { upgrade.Plan = organization.PlanType; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); @@ -61,7 +62,7 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSmSeats = 10; upgrade.AdditionalServiceAccounts = 10; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); Assert.Contains("already on this plan", exception.Message); @@ -73,11 +74,11 @@ public class UpgradeOrganizationPlanCommandTests SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.AdditionalSmSeats = 10; upgrade.AdditionalSeats = 10; upgrade.Plan = PlanType.TeamsAnnually; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -104,7 +105,7 @@ public class UpgradeOrganizationPlanCommandTests organization.PlanType = PlanType.FamiliesAnnually; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); organizationUpgrade.AdditionalSeats = 30; organizationUpgrade.UseSecretsManager = true; @@ -113,7 +114,7 @@ public class UpgradeOrganizationPlanCommandTests organizationUpgrade.AdditionalStorageGb = 3; organizationUpgrade.Plan = planType; - sutProvider.GetDependency().GetPlanOrThrow(organizationUpgrade.Plan).Returns(StaticStore.GetPlan(organizationUpgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(organizationUpgrade.Plan).Returns(MockPlans.Get(organizationUpgrade.Plan)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -121,9 +122,9 @@ public class UpgradeOrganizationPlanCommandTests Users = 1 }); await sutProvider.Sut.UpgradePlanAsync(organization.Id, organizationUpgrade); - await sutProvider.GetDependency().Received(1).AdjustSubscription( + await sutProvider.GetDependency().Received(1).AdjustSubscription( organization, - StaticStore.GetPlan(planType), + MockPlans.Get(planType), organizationUpgrade.AdditionalSeats, organizationUpgrade.UseSecretsManager, organizationUpgrade.AdditionalSmSeats, @@ -141,12 +142,12 @@ public class UpgradeOrganizationPlanCommandTests public async Task UpgradePlan_SM_Passes(PlanType planType, Organization organization, OrganizationUpgrade upgrade, SutProvider sutProvider) { - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.Plan = planType; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); - var plan = StaticStore.GetPlan(upgrade.Plan); + var plan = MockPlans.Get(upgrade.Plan); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); @@ -184,10 +185,10 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSeats = 15; upgrade.AdditionalSmSeats = 1; upgrade.AdditionalServiceAccounts = 0; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); organization.SmSeats = 2; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() @@ -218,11 +219,11 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSeats = 15; upgrade.AdditionalSmSeats = 1; upgrade.AdditionalServiceAccounts = 0; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); organization.SmSeats = 1; organization.SmServiceAccounts = currentServiceAccounts; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index d624bebf51..b98c4580f5 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -268,4 +268,115 @@ public class HandlebarsMailServiceTests // Assert await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any()); } + + [Fact] + public async Task SendIndividualUserWelcomeEmailAsync_SendsCorrectEmail() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com" + }; + + // Act + await _sut.SendIndividualUserWelcomeEmailAsync(user); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("test@example.com") && + m.Subject == "Welcome to Bitwarden!" && + m.Category == "Welcome")); + } + + [Fact] + public async Task SendOrganizationUserWelcomeEmailAsync_SendsCorrectEmailWithOrganizationName() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "user@company.com" + }; + var organizationName = "Bitwarden Corp"; + + // Act + await _sut.SendOrganizationUserWelcomeEmailAsync(user, organizationName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("user@company.com") && + m.Subject == "Welcome to Bitwarden!" && + m.HtmlContent.Contains("Bitwarden Corp") && + m.Category == "Welcome")); + } + + [Fact] + public async Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync_SendsCorrectEmailWithFamilyTemplate() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "family@example.com" + }; + var familyOrganizationName = "Smith Family"; + + // Act + await _sut.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, familyOrganizationName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("family@example.com") && + m.Subject == "Welcome to Bitwarden!" && + m.HtmlContent.Contains("Smith Family") && + m.Category == "Welcome")); + } + + [Theory] + [InlineData("Acme Corp", "Acme Corp")] + [InlineData("Company & Associates", "Company & Associates")] + [InlineData("Test \"Quoted\" Org", "Test "Quoted" Org")] + public async Task SendOrganizationUserWelcomeEmailAsync_SanitizesOrganizationNameForEmail(string inputOrgName, string expectedSanitized) + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com" + }; + + // Act + await _sut.SendOrganizationUserWelcomeEmailAsync(user, inputOrgName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.HtmlContent.Contains(expectedSanitized) && + !m.HtmlContent.Contains("